Example No. 1
    def __init__(self,
                 root_dir="data/ShapeNetCore.v2.PC15k",
                 categories=['airplane'],
                 tr_sample_size=10000,
                 te_sample_size=2048,
                 split='train',
                 scale=1.,
                 normalize_per_shape=False,
                 random_subsample=False):
        self.root_dir = root_dir
        self.split = split
        assert self.split in ['train', 'test', 'val']
        self.tr_sample_size = tr_sample_size
        self.te_sample_size = te_sample_size
        self.cates = categories
        self.cat = self.cates
        if 'all' in self.cates:
            self.synset_ids = list(cate_to_synsetid.values())
            self.cates = list(cate_to_synsetid.keys())
            self.cat = list(cate_to_synsetid.keys())
        else:
            self.synset_ids = [cate_to_synsetid[c] for c in self.cates]

        self.perCatValueMeter = {}
        for item in self.cat:
            self.perCatValueMeter[item] = AverageValueMeter()
        self.perCatValueMeter_metro = {}
        for item in self.cat:
            self.perCatValueMeter_metro[item] = AverageValueMeter()

        self.gravity_axis = 1
        self.display_axis_order = [0, 2, 1]

        super(ShapeNet15kPointClouds,
              self).__init__(root_dir,
                             self.synset_ids,
                             tr_sample_size=tr_sample_size,
                             te_sample_size=te_sample_size,
                             split=split,
                             scale=scale,
                             normalize_per_shape=normalize_per_shape,
                             random_subsample=random_subsample)
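
# Illustrative usage sketch (an assumption, not part of the original example):
# how the dataset defined above might be wrapped in a DataLoader, assuming the
# class and the cate_to_synsetid mapping are importable from this module.
from torch.utils.data import DataLoader

airplane_train = ShapeNet15kPointClouds(
    root_dir="data/ShapeNetCore.v2.PC15k",
    categories=['airplane'],      # ['all'] expands to every synset
    split='train',
    random_subsample=True)
airplane_loader = DataLoader(airplane_train, batch_size=16, shuffle=True,
                             num_workers=4, pin_memory=True)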
def val(model, dataloader, mse_loss, rate_loss, ps, epoch, show_inf_imgs=True):

    revert_transforms = transforms.Compose([
        transforms.Normalize((-1, -1, -1), (2, 2, 2)),
        transforms.Lambda(lambda tensor: t.clamp(tensor, 0.0, 1.0)),
    ])

    if ps:
        ps.log('validating ... ')
    else:
        print('run val ... ')
    model.eval()
    # avg_loss = 0
    # progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), ascii=True)

    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    mse_loss_meter.reset()
    if opt.use_imp:
        rate_loss_meter.reset()
        rate_display_meter.reset()
        total_loss_meter.reset()

    for idx, (data, _) in enumerate(dataloader):
        val_input = Variable(data, volatile=True)
        if opt.use_gpu:
            val_input = val_input.cuda(non_blocking=True)

        reconstructed = model(val_input)

        batch_loss = mse_loss(reconstructed, val_input)
        batch_caffe_loss = batch_loss / (2 * opt.batch_size)

        if opt.use_imp and rate_loss:
            # assumption: the importance map is exposed by the model (as in the
            # ImageNet train() below); the original line referenced an undefined name
            rate_loss_display = (model.module if hasattr(model, 'module') else model).imp_mask_sigmoid
            rate_loss_value = rate_loss(rate_loss_display)
            total_loss = batch_caffe_loss + rate_loss_value

        mse_loss_meter.add(batch_caffe_loss.data[0])

        if opt.use_imp:
            rate_loss_meter.add(rate_loss_value.data[0])
            rate_display_meter.add(rate_loss_display.data.mean())
            total_loss_meter.add(total_loss.data[0])

        if show_inf_imgs:
            if idx > 5:
                continue
            reconstructed_imgs = revert_transforms(reconstructed.data[0])
            # pdb.set_trace()
            # print ('save imgs ...')

            dir_path = os.path.join(opt.save_inf_imgs_path, 'epoch_%d' % epoch)
            if not os.path.exists(dir_path):
                # print ('mkdir', dir_path)
                os.makedirs(dir_path)
            tv.utils.save_image(
                reconstructed_imgs,
                os.path.join(
                    dir_path, 'test_p{ps}_epoch{epoch}_idx{idx}.png'.format(
                        ps=opt.patch_size, epoch=epoch, idx=idx)))
            tv.utils.save_image(
                revert_transforms(val_input.data[0]),
                os.path.join(dir_path, 'origin_{idx}.png'.format(idx=idx)))

    if opt.use_imp:
        return mse_loss_meter.value()[0], rate_loss_meter.value(
        )[0], total_loss_meter.value()[0], rate_display_meter.value()[0]
    else:
        return mse_loss_meter.value()[0]
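
# Quick check (added sketch): the revert_transforms used in val() exactly
# inverts the Normalize(mean=0.5, std=0.5) applied at training time, mapping
# [-1, 1] tensors back to [0, 1] images before saving them.
import torch as t
from torchvision import transforms

forward = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))   # x -> (x - 0.5) / 0.5
revert = transforms.Compose([
    transforms.Normalize((-1, -1, -1), (2, 2, 2)),                  # y -> (y + 1) / 2
    transforms.Lambda(lambda tensor: t.clamp(tensor, 0.0, 1.0)),
])
img = t.rand(3, 8, 8)                       # fake image in [0, 1]
assert t.allclose(revert(forward(img.clone())), img, atol=1e-6)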
def train(**kwargs):
    opt.parse(kwargs)
    # log file
    # tpse: test patch size effect
    ps = PlotSaver(opt.exp_id[4:] + "_" + time.strftime("%m_%d__%H:%M:%S") +
                   ".log.txt")

    # step1: Model
    model = getattr(models, opt.model)(
        use_imp=opt.use_imp,
        model_name="TPSE_p{ps}_imp_r={r}_gama={w}".format(
            ps=opt.patch_size,
            r=opt.rate_loss_threshold,
            w=opt.rate_loss_weight)
        if opt.use_imp else "TPSE_p{ps}_no_imp".format(ps=opt.patch_size))

    if opt.use_gpu:
        model = multiple_gpu_process(model)

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    train_data_transforms = transforms.Compose([
        transforms.RandomCrop(opt.patch_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ])

    val_data_transforms = transforms.Compose(
        [transforms.CenterCrop(256),
         transforms.ToTensor(), normalize])

    train_data = datasets.ImageFolder(opt.train_data_root,
                                      train_data_transforms)
    val_data = datasets.ImageFolder(opt.val_data_root, val_data_transforms)

    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers,
                                  pin_memory=True)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers,
                                pin_memory=True)

    # step3: criterion and optimizer

    mse_loss = t.nn.MSELoss(size_average=False)
    if opt.use_imp:
        rate_loss = LimuRateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)
    lr = opt.lr
    optimizer = t.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

    start_epoch = 0

    if opt.resume:
        if use_data_parallel:
            start_epoch = model.module.load(
                None if opt.finetune else optimizer, opt.resume, opt.finetune)
        else:
            start_epoch = model.load(None if opt.finetune else optimizer,
                                     opt.resume, opt.finetune)

        if opt.finetune:
            print('Finetune from model checkpoint file', opt.resume)
        else:
            print('Resume training from checkpoint file', opt.resume)
            print('Continue training at epoch %d.' % start_epoch)

    # step4: meters
    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    previous_loss = 1e100
    tolerant_now = 0
    same_lr_epoch = 0

    # ps init

    ps.new_plot('train mse loss',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_mse_loss")
    ps.new_plot('val mse loss', 1, xlabel="epoch", ylabel="val_mse_loss")
    if opt.use_imp:
        ps.new_plot('train rate value',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_value")
        ps.new_plot('train rate loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_loss")
        ps.new_plot('train total loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_total_loss")
        ps.new_plot('val rate value',
                    1,
                    xlabel="epoch",
                    ylabel="val_rate_value")
        ps.new_plot('val rate loss', 1, xlabel="epoch", ylabel="val_rate_loss")
        ps.new_plot('val total loss',
                    1,
                    xlabel="epoch",
                    ylabel="val_total_loss")

    for epoch in range(start_epoch + 1, opt.max_epoch + 1):

        same_lr_epoch += 1
        # per epoch avg loss meter
        mse_loss_meter.reset()
        if opt.use_imp:
            rate_display_meter.reset()
            rate_loss_meter.reset()
            total_loss_meter.reset()
        else:
            total_loss_meter = mse_loss_meter

        ps.new_plot("cur epoch train mse loss",
                    opt.print_freq,
                    xlabel="iteration in cur epoch",
                    ylabel="train_mse_loss")

        # Init val
        if (epoch == start_epoch + 1) and opt.init_val:
            print('Init validation ... ')
            if opt.use_imp:
                mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                    model, val_dataloader, mse_loss, rate_loss, ps, -1, True)
            else:
                mse_val_loss = val(model, val_dataloader, mse_loss, None, ps,
                                   -1, True)

            ps.add_point('val mse loss', mse_val_loss)
            if opt.use_imp:
                ps.add_point('val rate value', rate_val_display)
                ps.add_point('val rate loss', rate_val_loss)
                ps.add_point('val total loss', total_val_loss)

            # make plot
            ps.make_plot('val mse loss')

            if opt.use_imp:
                ps.make_plot('val rate value')
                ps.make_plot('val rate loss')
                ps.make_plot('val total loss')

            # log sth.
            if opt.use_imp:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                    .format(epoch=epoch,
                            lr=lr,
                            val_mse_loss=mse_val_loss,
                            val_rate_loss=rate_val_loss,
                            val_total_loss=total_val_loss,
                            val_rate_display=rate_val_display))
            else:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss:{val_mse_loss}'
                    .format(epoch=epoch, lr=lr, val_mse_loss=mse_val_loss))

        # start training
        model.train()

        # _ is the class label; compression doesn't use it.
        for idx, (data, _) in enumerate(train_dataloader):

            ipt = Variable(data)

            if opt.use_gpu:
                ipt = ipt.cuda(non_blocking=True)

            optimizer.zero_grad()

            reconstructed = model(ipt)

            loss = mse_loss(reconstructed, ipt)
            caffe_loss = loss / (2 * opt.batch_size)

            # Normalise relative to an 8x8 patch: MSELoss(size_average=False) sums over
            # pixels, so a patch with side p contributes (p/8)^2 times as many terms as
            # an 8x8 patch (a short numeric check follows this function).
            caffe_loss = caffe_loss / ((opt.patch_size // 8)**2)

            if opt.use_imp:
                # assumption: the importance map comes from the model itself, as in the
                # ImageNet train() below; the original line referenced an undefined name
                rate_loss_display = (model.module if use_data_parallel else model).imp_mask_sigmoid
                rate_loss_ = rate_loss(rate_loss_display)

                total_loss = caffe_loss + rate_loss_
            else:
                total_loss = caffe_loss

            total_loss.backward()

            optimizer.step()

            mse_loss_meter.add(caffe_loss.data[0])

            if opt.use_imp:
                rate_loss_meter.add(rate_loss_.data[0])
                rate_display_meter.add(rate_loss_display.data.mean())
                total_loss_meter.add(total_loss.data[0])

            if idx % opt.print_freq == opt.print_freq - 1:
                ps.add_point(
                    'train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                ps.add_point(
                    'cur epoch train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                # print (rate_loss_display.data.mean())
                if opt.use_imp:
                    ps.add_point(
                        'train rate value',
                        rate_display_meter.value()[0]
                        if opt.print_smooth else rate_loss_display.data.mean())
                    ps.add_point(
                        'train rate loss',
                        rate_loss_meter.value()[0]
                        if opt.print_smooth else rate_loss_.data[0])
                    ps.add_point(
                        'train total loss',
                        total_loss_meter.value()[0]
                        if opt.print_smooth else total_loss.data[0])

                if not opt.use_imp:
                    ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' %
                           (epoch, opt.max_epoch, idx, len(train_dataloader),
                            total_loss_meter.value()[0], lr))
                else:
                    ps.log(
                        'Epoch %d/%d, Iter %d/%d, loss = %.2f, mse_loss = %.2f, rate_loss = %.2f, rate_display = %.2f, lr = %.8f'
                        %
                        (epoch, opt.max_epoch, idx, len(train_dataloader),
                         total_loss_meter.value()[0],
                         mse_loss_meter.value()[0], rate_loss_meter.value()[0],
                         rate_display_meter.value()[0], lr))
                # enter debug mode: drop into pdb if the debug file exists
                if os.path.exists(opt.debug_file):
                    # import pdb
                    pdb.set_trace()

        if use_data_parallel:
            model.module.save(optimizer, epoch)
        else:
            model.save(optimizer, epoch)

        # plot before validation so progress is visible even if val is slow
        ps.make_plot('train mse loss')  # all epochs share one image, so epoch keeps its default
        ps.make_plot('cur epoch train mse loss', epoch)
        if opt.use_imp:
            ps.make_plot("train rate value")
            ps.make_plot("train rate loss")
            ps.make_plot("train total loss")

        # val
        if opt.use_imp:
            mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                model, val_dataloader, mse_loss, rate_loss, ps, epoch,
                epoch % opt.show_inf_imgs_T == 0)
        else:
            mse_val_loss = val(model, val_dataloader, mse_loss, None, ps,
                               epoch, epoch % opt.show_inf_imgs_T == 0)

        ps.add_point('val mse loss', mse_val_loss)
        if opt.use_imp:
            ps.add_point('val rate value', rate_val_display)
            ps.add_point('val rate loss', rate_val_loss)
            ps.add_point('val total loss', total_val_loss)

        ps.make_plot('val mse loss')

        if opt.use_imp:
            ps.make_plot('val rate value')
            ps.make_plot('val rate loss')
            ps.make_plot('val total loss')

        # log sth.
        if opt.use_imp:
            ps.log(
                'Epoch:{epoch}, lr:{lr}, train_mse_loss: {train_mse_loss}, train_rate_loss: {train_rate_loss}, train_total_loss: {train_total_loss}, train_rate_display: {train_rate_display} \n\
val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                .format(epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        train_rate_loss=rate_loss_meter.value()[0],
                        train_total_loss=total_loss_meter.value()[0],
                        train_rate_display=rate_display_meter.value()[0],
                        val_mse_loss=mse_val_loss,
                        val_rate_loss=rate_val_loss,
                        val_total_loss=total_val_loss,
                        val_rate_display=rate_val_display))
        else:
            ps.log(
                'Epoch:{epoch}, lr:{lr}, train_mse_loss:{train_mse_loss}, val_mse_loss:{val_mse_loss}'
                .format(epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        val_mse_loss=mse_val_loss))

        # Adaptively adjust lr: for each lr, if the epoch loss exceeds the previous
        # epoch's loss opt.tolerant_max times, anneal the learning rate early.
        if opt.use_early_adjust:
            if total_loss_meter.value()[0] > previous_loss:
                tolerant_now += 1
                if tolerant_now == opt.tolerant_max:
                    tolerant_now = 0
                    same_lr_epoch = 0
                    lr = lr * opt.lr_decay
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Anneal lr to %.10f at epoch %d due to early stop.' %
                          (lr, epoch))
                    ps.log('Anneal lr to %.10f at epoch %d due to early stop.' %
                           (lr, epoch))

            else:
                tolerant_now -= 1

        if same_lr_epoch and same_lr_epoch % opt.lr_anneal_epochs == 0:
            same_lr_epoch = 0
            tolerant_now = 0
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print('Anneal lr to %.10f at epoch %d due to full epochs.' % (lr, epoch))
            ps.log('Anneal lr to %.10f at epoch %d due to full epochs.' %
                   (lr, epoch))

        if opt.use_file_decay_lr and os.path.exists(opt.lr_decay_file):
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        previous_loss = total_loss_meter.value()[0]
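
# Numeric check (added sketch): the patch-size normalisation used in train()
# above. Because MSELoss(size_average=False) sums over pixels, dividing by
# (patch_size // 8) ** 2 keeps the per-patch loss magnitude comparable across
# the patch sizes being tested.
for patch_size in (8, 16, 32, 64):
    n_pixels = patch_size * patch_size
    scale = (patch_size // 8) ** 2
    print(patch_size, n_pixels, n_pixels / scale)   # normalised pixel count is always 64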
def train(**kwargs):
    opt.parse(kwargs)
    # vis = Visualizer(opt.env)
    # log file
    ps = PlotSaver("Train_ImageNet12_With_ImpMap_" +
                   time.strftime("%m_%d_%H:%M:%S") + ".log.txt")

    # step1: Model
    model = getattr(models, opt.model)(
        use_imp=opt.use_imp,
        model_name="CWCNN_limu_ImageNet_imp_r={r}_γ={w}".format(
            r=opt.rate_loss_threshold, w=opt.rate_loss_weight)
        if opt.use_imp else None)
    # if opt.use_imp else "test_pytorch")
    if opt.use_gpu:
        # model = multiple_gpu_process(model)
        model.cuda()


    # pdb.set_trace()

    cudnn.benchmark = True

    # step2: Data
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_data_transforms = transforms.Compose([
        transforms.Resize(256),
        #transforms.Scale(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    val_data_transforms = transforms.Compose([
        transforms.Resize(256),
        #transforms.Scale(256),
        transforms.CenterCrop(224),
        # transforms.TenCrop(224),
        # transforms.Lambda(lambda crops: t.stack(([normalize(transforms.ToTensor()(crop)) for crop in crops]))),
        transforms.ToTensor(),
        normalize
    ])
    # train_data = ImageNet_200k(opt.train_data_root, train=True, transforms=data_transforms)
    # val_data = ImageNet_200k(opt.val_data_root, train = False, transforms=data_transforms)
    train_data = datasets.ImageFolder(opt.train_data_root,
                                      train_data_transforms)
    val_data = datasets.ImageFolder(opt.val_data_root, val_data_transforms)
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers,
                                  pin_memory=True)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers,
                                pin_memory=True)

    # step3: criterion and optimizer

    mse_loss = t.nn.MSELoss(size_average=False)

    if opt.use_imp:
        # rate_loss = RateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)
        rate_loss = LimuRateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)

    lr = opt.lr

    optimizer = t.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

    start_epoch = 0

    if opt.resume:
        if hasattr(model, 'module'):
            start_epoch = model.module.load(
                None if opt.finetune else optimizer, opt.resume, opt.finetune)
        else:
            start_epoch = model.load(None if opt.finetune else optimizer,
                                     opt.resume, opt.finetune)

        if opt.finetune:
            print('Finetune from model checkpoint file', opt.resume)
        else:
            print('Resume training from checkpoint file', opt.resume)
            print('Continue training at epoch %d.' % start_epoch)

    # step4: meters
    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    previous_loss = 1e100
    tolerant_now = 0
    same_lr_epoch = 0

    # ps init

    ps.new_plot('train mse loss',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_mse_loss")
    ps.new_plot('val mse loss', 1, xlabel="epoch", ylabel="val_mse_loss")
    if opt.use_imp:
        ps.new_plot('train rate value',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_value")
        ps.new_plot('train rate loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_loss")
        ps.new_plot('train total loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_total_loss")
        ps.new_plot('val rate value',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_value")
        ps.new_plot('val rate loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_loss")
        ps.new_plot('val total loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_total_loss")

    for epoch in range(start_epoch + 1, opt.max_epoch + 1):
        same_lr_epoch += 1
        # per epoch avg loss meter
        mse_loss_meter.reset()
        if opt.use_imp:
            rate_display_meter.reset()
            rate_loss_meter.reset()
            total_loss_meter.reset()
        else:
            total_loss_meter = mse_loss_meter
        # cur_epoch_loss refresh every epoch

        ps.new_plot("cur epoch train mse loss",
                    opt.print_freq,
                    xlabel="iteration in cur epoch",
                    ylabel="train_mse_loss")

        model.train()

        # _ is the class label; compression doesn't use it.
        for idx, (data, _) in enumerate(train_dataloader):
            ipt = Variable(data)

            if opt.use_gpu:
                ipt = ipt.cuda()

            optimizer.zero_grad()  # don't forget to clear gradients before backward
            reconstructed = model(ipt)

            # print ('reconstructed tensor size :', reconstructed.size())
            loss = mse_loss(reconstructed, ipt)
            caffe_loss = loss / (2 * opt.batch_size)

            if opt.use_imp:
                # print ('use data_parallel?',use_data_parallel)
                # pdb.set_trace()
                rate_loss_display = (model.module if use_data_parallel else
                                     model).imp_mask_sigmoid
                rate_loss_ = rate_loss(rate_loss_display)
                total_loss = caffe_loss + rate_loss_
            else:
                total_loss = caffe_loss

            total_loss.backward()

            optimizer.step()

            mse_loss_meter.add(caffe_loss.data[0])

            if opt.use_imp:
                rate_loss_meter.add(rate_loss_.data[0])
                rate_display_meter.add(rate_loss_display.data.mean())
                total_loss_meter.add(total_loss.data[0])

            if idx % opt.print_freq == opt.print_freq - 1:
                ps.add_point(
                    'train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                ps.add_point(
                    'cur epoch train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                if opt.use_imp:
                    ps.add_point(
                        'train rate value',
                        rate_display_meter.value()[0]
                        if opt.print_smooth else rate_loss_display.data.mean())
                    ps.add_point(
                        'train rate loss',
                        rate_loss_meter.value()[0]
                        if opt.print_smooth else rate_loss_.data[0])
                    ps.add_point(
                        'train total loss',
                        total_loss_meter.value()[0]
                        if opt.print_smooth else total_loss.data[0])
                if not opt.use_imp:
                    ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' %
                           (epoch, opt.max_epoch, idx, len(train_dataloader),
                            total_loss_meter.value()[0], lr))
                else:
                    ps.log(
                        'Epoch %d/%d, Iter %d/%d, loss = %.2f, mse_loss = %.2f, rate_loss = %.2f, rate_display = %.2f, lr = %.8f'
                        %
                        (epoch, opt.max_epoch, idx, len(train_dataloader),
                         total_loss_meter.value()[0],
                         mse_loss_meter.value()[0], rate_loss_meter.value()[0],
                         rate_display_meter.value()[0], lr))
                # enter debug mode: drop into pdb if the debug file exists
                if os.path.exists(opt.debug_file):
                    pdb.set_trace()

        # data parallel
        # if hasattr(model, 'module'):
        if use_data_parallel:
            model.module.save(optimizer, epoch)
        else:
            model.save(optimizer, epoch)

        # plot before validation so progress is visible even if val is slow
        ps.make_plot('train mse loss')  # all epochs share one image, so epoch keeps its default
        ps.make_plot('cur epoch train mse loss', epoch)
        if opt.use_imp:
            ps.make_plot("train rate value")
            ps.make_plot("train rate loss")
            ps.make_plot("train total loss")

        # val
        if opt.use_imp:
            mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                model, val_dataloader, mse_loss, rate_loss, ps)
        else:
            mse_val_loss = val(model, val_dataloader, mse_loss, None, ps)

        ps.add_point('val mse loss', mse_val_loss)
        if opt.use_imp:
            ps.add_point('val rate value', rate_val_display)
            ps.add_point('val rate loss', rate_val_loss)
            ps.add_point('val total loss', total_val_loss)

        ps.make_plot('val mse loss')

        if opt.use_imp:
            ps.make_plot('val rate value')
            ps.make_plot('val rate loss')
            ps.make_plot('val total loss')

        # log sth.
        if opt.use_imp:
            ps.log(
                'Epoch:{epoch}, lr:{lr}, train_mse_loss: {train_mse_loss}, train_rate_loss: {train_rate_loss}, train_total_loss: {train_total_loss}, train_rate_display: {train_rate_display} \n\
val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                .format(epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        train_rate_loss=rate_loss_meter.value()[0],
                        train_total_loss=total_loss_meter.value()[0],
                        train_rate_display=rate_display_meter.value()[0],
                        val_mse_loss=mse_val_loss,
                        val_rate_loss=rate_val_loss,
                        val_total_loss=total_val_loss,
                        val_rate_display=rate_val_display))
        else:
            ps.log(
                'Epoch:{epoch}, lr:{lr}, train_mse_loss:{train_mse_loss}, val_mse_loss:{val_mse_loss}'
                .format(epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        val_mse_loss=mse_val_loss))

        # Adaptively adjust lr: for each lr, if the epoch loss exceeds the previous
        # epoch's loss opt.tolerant_max times, anneal the learning rate early.
        if opt.use_early_adjust:
            if total_loss_meter.value()[0] > previous_loss:
                tolerant_now += 1
                if tolerant_now == opt.tolerant_max:
                    tolerant_now = 0
                    same_lr_epoch = 0
                    lr = lr * opt.lr_decay
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Anneal lr to', lr, 'at epoch', epoch,
                          'due to early stop.')
                    ps.log(
                        'Anneal lr to %.10f at epoch %d due to early stop.' %
                        (lr, epoch))

            else:
                tolerant_now -= 1

        if same_lr_epoch and same_lr_epoch % opt.lr_anneal_epochs == 0:
            same_lr_epoch = 0
            tolerant_now = 0
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print('Anneal lr to', lr, 'at epoch', epoch, 'due to full epochs.')
            ps.log('Anneal lr to %.10f at epoch %d due to full epochs.' %
                   (lr, epoch))

        previous_loss = total_loss_meter.value()[0]
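
# Helper sketch (an assumption, not part of the original file): the lr-annealing
# pattern repeated throughout the train() functions, factored out. Writing to
# optimizer.param_groups[i]['lr'] is the standard way to change the rate in place.
def anneal_lr(optimizer, lr, decay):
    """Multiply the learning rate by `decay` and push it to every param group."""
    lr = lr * decay
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

# usage inside the epoch loop:  lr = anneal_lr(optimizer, lr, opt.lr_decay)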
def train(**kwargs):
    opt.parse(kwargs)
    # log file
    logfile_name = "Cmpr_with_YOLOv2_" + opt.exp_desc + time.strftime(
        "_%m_%d_%H:%M:%S") + ".log.txt"
    ps = PlotSaver(logfile_name)

    # step1: Model
    model = getattr(models, opt.model)(
        use_imp=opt.use_imp,
        n=opt.feat_num,
        input_4_ch=opt.input_4_ch,
        model_name="Cmpr_yolo_imp_" + opt.exp_desc + "_r={r}_gama={w}".format(
            r=opt.rate_loss_threshold, w=opt.rate_loss_weight)
        if opt.use_imp else "Cmpr_yolo_no_imp_" + opt.exp_desc)
    # pdb.set_trace()
    if opt.use_gpu:
        model = multiple_gpu_process(model)

    cudnn.benchmark = True

    # step2: Data
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    train_data_transforms = transforms.Compose([
        # transforms.RandomHorizontalFlip(),  # TODO: reimplement to flip data and label masks together
        transforms.ToTensor(),
        normalize
    ])
    val_data_transforms = transforms.Compose(
        [transforms.ToTensor(), normalize])
    train_data = ImageCropWithBBoxMaskDataset(
        opt.train_data_list,
        train_data_transforms,
        contrastive_degree=opt.contrastive_degree,
        mse_bbox_weight=opt.input_original_bbox_weight)
    val_data = ImageCropWithBBoxMaskDataset(
        opt.val_data_list,
        val_data_transforms,
        contrastive_degree=opt.contrastive_degree,
        mse_bbox_weight=opt.input_original_bbox_weight)
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers,
                                  pin_memory=True)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers,
                                pin_memory=True)

    # step3: criterion and optimizer

    mse_loss = t.nn.MSELoss(size_average=False)

    if opt.use_imp:
        # TODO: new rate loss
        rate_loss = RateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)
        # rate_loss = LimuRateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)

    def weighted_mse_loss(input, target, weight):
        # Up-weight pixels inside the detection bounding box: positions whose weight
        # equals opt.input_original_bbox_inner get opt.mse_bbox_weight; the rest keep
        # weight 1 (a toy-tensor example follows this train() function).
        weight_clone = t.ones_like(weight)
        weight_clone[weight ==
                     opt.input_original_bbox_inner] = opt.mse_bbox_weight
        return t.sum(weight_clone * (input - target)**2)

    def yolo_rate_loss(imp_map, mask_r):
        # V2 variant (requires contrastive_degree == 0):
        # return YoloRateLossV2(mask_r, opt.rate_loss_threshold, opt.rate_loss_weight)(imp_map)
        return rate_loss(imp_map)

    lr = opt.lr
    optimizer = t.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

    start_epoch = 0
    decay_file_create_time = -1  # track the decay file's mtime so the same file doesn't repeatedly anneal the lr

    if opt.resume:
        if use_data_parallel:
            start_epoch = model.module.load(
                None if opt.finetune else optimizer, opt.resume, opt.finetune)
        else:
            start_epoch = model.load(None if opt.finetune else optimizer,
                                     opt.resume, opt.finetune)

        if opt.finetune:
            print('Finetune from model checkpoint file', opt.resume)
        else:
            print('Resume training from checkpoint file', opt.resume)
            print('Continue training at epoch %d.' % start_epoch)

    # step4: meters
    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    previous_loss = 1e100
    tolerant_now = 0
    same_lr_epoch = 0

    # ps init

    ps.new_plot('train mse loss',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_mse_loss")
    ps.new_plot('val mse loss', 1, xlabel="epoch", ylabel="val_mse_loss")
    if opt.use_imp:
        ps.new_plot('train rate value',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_value")
        ps.new_plot('train rate loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_loss")
        ps.new_plot('train total loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_total_loss")
        ps.new_plot('val rate value',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_value")
        ps.new_plot('val rate loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_loss")
        ps.new_plot('val total loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_total_loss")

    for epoch in range(start_epoch + 1, opt.max_epoch + 1):

        same_lr_epoch += 1
        # per epoch avg loss meter
        mse_loss_meter.reset()
        if opt.use_imp:
            rate_display_meter.reset()
            rate_loss_meter.reset()
            total_loss_meter.reset()
        else:
            total_loss_meter = mse_loss_meter
        # cur_epoch_loss refresh every epoch
        # vis.refresh_plot('cur epoch train mse loss')
        ps.new_plot("cur epoch train mse loss",
                    opt.print_freq,
                    xlabel="iteration in cur epoch",
                    ylabel="train_mse_loss")
        # progress_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader), ascii=True)
        # progress_bar.set_description('epoch %d/%d, loss = 0.00' % (epoch, opt.max_epoch))

        # Init val
        if (epoch == start_epoch + 1) and opt.init_val:
            print('Init validation ... ')
            if opt.use_imp:
                mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                    model, val_dataloader, weighted_mse_loss, yolo_rate_loss,
                    ps)
            else:
                mse_val_loss = val(model, val_dataloader, weighted_mse_loss,
                                   None, ps)

            ps.add_point('val mse loss', mse_val_loss)
            if opt.use_imp:
                ps.add_point('val rate value', rate_val_display)
                ps.add_point('val rate loss', rate_val_loss)
                ps.add_point('val total loss', total_val_loss)

            ps.make_plot('val mse loss')

            if opt.use_imp:
                ps.make_plot('val rate value')
                ps.make_plot('val rate loss')
                ps.make_plot('val total loss')

            # log sth.
            if opt.use_imp:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                    .format(epoch=epoch,
                            lr=lr,
                            val_mse_loss=mse_val_loss,
                            val_rate_loss=rate_val_loss,
                            val_total_loss=total_val_loss,
                            val_rate_display=rate_val_display))
            else:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss:{val_mse_loss}'
                    .format(epoch=epoch, lr=lr, val_mse_loss=mse_val_loss))

        if opt.only_init_val:
            print('Only Init Val Over!')
            return

        model.train()

        if epoch == start_epoch + 1:
            print('Start training, please inspect log file %s!' % logfile_name)
        # mask is the detection bounding box mask
        for idx, (data, mask, o_mask) in enumerate(train_dataloader):

            # pdb.set_trace()

            data = Variable(data)
            mask = Variable(mask)
            o_mask = Variable(o_mask, requires_grad=False)

            if opt.use_gpu:
                data = data.cuda(non_blocking=True)
                mask = mask.cuda(non_blocking=True)
                o_mask = o_mask.cuda(non_blocking=True)

            # pdb.set_trace()

            optimizer.zero_grad()
            reconstructed, imp_mask_sigmoid = model(data, mask, o_mask)

            # print ('imp_mask_height', model.imp_mask_height)
            # pdb.set_trace()

            # print ('type recons', type(reconstructed.data))

            loss = weighted_mse_loss(reconstructed, data, o_mask)
            # loss = mse_loss(reconstructed, data)
            caffe_loss = loss / (2 * opt.batch_size)

            if opt.use_imp:
                rate_loss_display = imp_mask_sigmoid
                # rate_loss_ =  rate_loss(rate_loss_display)
                rate_loss_ = yolo_rate_loss(rate_loss_display, mask)
                total_loss = caffe_loss + rate_loss_
            else:
                total_loss = caffe_loss

            total_loss.backward()
            optimizer.step()

            mse_loss_meter.add(caffe_loss.data[0])

            if opt.use_imp:
                rate_loss_meter.add(rate_loss_.data[0])
                rate_display_meter.add(rate_loss_display.data.mean())
                total_loss_meter.add(total_loss.data[0])

            if idx % opt.print_freq == opt.print_freq - 1:
                ps.add_point(
                    'train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                ps.add_point(
                    'cur epoch train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                if opt.use_imp:
                    ps.add_point(
                        'train rate value',
                        rate_display_meter.value()[0]
                        if opt.print_smooth else rate_loss_display.data.mean())
                    ps.add_point(
                        'train rate loss',
                        rate_loss_meter.value()[0]
                        if opt.print_smooth else rate_loss_.data[0])
                    ps.add_point(
                        'train total loss',
                        total_loss_meter.value()[0]
                        if opt.print_smooth else total_loss.data[0])

                if not opt.use_imp:
                    ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' %
                           (epoch, opt.max_epoch, idx, len(train_dataloader),
                            total_loss_meter.value()[0], lr))
                else:
                    ps.log(
                        'Epoch %d/%d, Iter %d/%d, loss = %.2f, mse_loss = %.2f, rate_loss = %.2f, rate_display = %.2f, lr = %.8f'
                        %
                        (epoch, opt.max_epoch, idx, len(train_dataloader),
                         total_loss_meter.value()[0],
                         mse_loss_meter.value()[0], rate_loss_meter.value()[0],
                         rate_display_meter.value()[0], lr))

                # enter debug mode: drop into pdb if the debug file exists
                if os.path.exists(opt.debug_file):
                    pdb.set_trace()

        if epoch % opt.save_interval == 0:
            print('save checkpoint file of epoch %d.' % epoch)
            if use_data_parallel:
                model.module.save(optimizer, epoch)
            else:
                model.save(optimizer, epoch)

        ps.make_plot('train mse loss')
        ps.make_plot('cur epoch train mse loss', epoch)
        if opt.use_imp:
            ps.make_plot("train rate value")
            ps.make_plot("train rate loss")
            ps.make_plot("train total loss")

        if epoch % opt.eval_interval == 0:

            print('Validating ...')
            # val
            if opt.use_imp:
                mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                    model, val_dataloader, weighted_mse_loss, yolo_rate_loss,
                    ps)
            else:
                mse_val_loss = val(model, val_dataloader, weighted_mse_loss,
                                   None, ps)

            ps.add_point('val mse loss', mse_val_loss)
            if opt.use_imp:
                ps.add_point('val rate value', rate_val_display)
                ps.add_point('val rate loss', rate_val_loss)
                ps.add_point('val total loss', total_val_loss)

            ps.make_plot('val mse loss')

            if opt.use_imp:
                ps.make_plot('val rate value')
                ps.make_plot('val rate loss')
                ps.make_plot('val total loss')

            # log sth.
            if opt.use_imp:
                ps.log(
                    'Epoch:{epoch}, lr:{lr}, train_mse_loss: {train_mse_loss}, train_rate_loss: {train_rate_loss}, train_total_loss: {train_total_loss}, train_rate_display: {train_rate_display} \n\
    val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                    .format(epoch=epoch,
                            lr=lr,
                            train_mse_loss=mse_loss_meter.value()[0],
                            train_rate_loss=rate_loss_meter.value()[0],
                            train_total_loss=total_loss_meter.value()[0],
                            train_rate_display=rate_display_meter.value()[0],
                            val_mse_loss=mse_val_loss,
                            val_rate_loss=rate_val_loss,
                            val_total_loss=total_val_loss,
                            val_rate_display=rate_val_display))
            else:
                ps.log(
                    'Epoch:{epoch}, lr:{lr}, train_mse_loss:{train_mse_loss}, val_mse_loss:{val_mse_loss}'
                    .format(epoch=epoch,
                            lr=lr,
                            train_mse_loss=mse_loss_meter.value()[0],
                            val_mse_loss=mse_val_loss))

        # Adaptively adjust lr: for each lr, if the epoch loss exceeds the previous
        # epoch's loss opt.tolerant_max times, anneal the learning rate early.
        if opt.use_early_adjust:
            if total_loss_meter.value()[0] > previous_loss:
                tolerant_now += 1
                if tolerant_now == opt.tolerant_max:
                    tolerant_now = 0
                    same_lr_epoch = 0
                    lr = lr * opt.lr_decay
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Anneal lr to %.10f at epoch %d due to early stop.' %
                          (lr, epoch))
                    ps.log('Anneal lr to %.10f at epoch %d due to early stop.' %
                           (lr, epoch))

            else:
                tolerant_now -= 1

        if epoch % opt.lr_anneal_epochs == 0:
            # if same_lr_epoch and same_lr_epoch % opt.lr_anneal_epochs == 0:
            same_lr_epoch = 0
            tolerant_now = 0
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print('Anneal lr to %.10f at epoch %d due to full epochs.' %
                  (lr, epoch))
            ps.log('Anneal lr to %.10f at epoch %d due to full epochs.' %
                   (lr, epoch))

        if opt.use_file_decay_lr and os.path.exists(opt.lr_decay_file):
            cur_mtime = os.path.getmtime(opt.lr_decay_file)
            if cur_mtime > decay_file_create_time:
                decay_file_create_time = cur_mtime
                lr = lr * opt.lr_decay
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
                print(
                    'Anneal lr to %.10f at epoch %d due to decay-file indicator.'
                    % (lr, epoch))
                ps.log(
                    'Anneal lr to %.10f at epoch %d due to decay-file indicator.'
                    % (lr, epoch))

        previous_loss = total_loss_meter.value()[0]
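
# Toy-tensor example (added sketch): what weighted_mse_loss above computes.
# Pixels whose weight equals the "inside bounding box" marker are up-weighted;
# the marker and weight values below are made up for illustration.
import torch as t

inner_marker, bbox_weight = 0.0, 10.0
weight = t.tensor([[0.0, 1.0], [1.0, 0.0]])      # 0.0 marks bbox-interior pixels
inp = t.ones(2, 2)
target = t.zeros(2, 2)
w = t.ones_like(weight)
w[weight == inner_marker] = bbox_weight
print(t.sum(w * (inp - target) ** 2).item())     # 10 + 1 + 1 + 10 = 22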
Example No. 6
def main_worker(save_dir, args):
    # basic setup
    cudnn.benchmark = True

    if args.log_name is not None:
        log_dir = "runs/%s" % args.log_name
    else:
        log_dir = f"runs/{datetime.datetime.now().strftime('%m-%d-%H-%M-%S')}"

    if args.local_rank == 0:
        logger = SummaryWriter(log_dir)
    else:
        logger = None

    deepspeed.init_distributed(dist_backend='nccl')
    torch.cuda.set_device(args.local_rank)

    model = SetVAE(args)
    parameters = model.parameters()

    n_parameters = sum(p.numel() for p in parameters if p.requires_grad)
    print(f'number of params: {n_parameters}')
    try:
        n_gen_parameters = sum(p.numel() for p in model.init_set.parameters() if p.requires_grad) + \
                           sum(p.numel() for p in model.pre_decoder.parameters() if p.requires_grad) + \
                           sum(p.numel() for p in model.decoder.parameters() if p.requires_grad) + \
                           sum(p.numel() for p in model.post_decoder.parameters() if p.requires_grad) + \
                           sum(p.numel() for p in model.output.parameters() if p.requires_grad)
        print(f'number of generator params: {n_gen_parameters}')
    except AttributeError:
        pass

    optimizer, criterion = model.make_optimizer(args)

    # initialize datasets and loaders
    train_dataset, val_dataset, train_loader, val_loader = get_datasets(args)

    # initialize the learning rate scheduler
    if args.scheduler == 'exponential':
        assert not (args.warmup_epochs > 0)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, args.exp_decay)
    elif args.scheduler == 'step':
        assert not (args.warmup_epochs > 0)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=args.epochs // 2,
                                                    gamma=0.1)
    elif args.scheduler == 'linear':

        def lambda_rule(ep):
            lr_w = min(1., ep /
                       args.warmup_epochs) if (args.warmup_epochs > 0) else 1.
            lr_l = 1.0 - max(0, ep - 0.5 * args.epochs) / float(
                0.5 * args.epochs)
            return lr_l * lr_w

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lr_lambda=lambda_rule)
    elif args.scheduler == 'cosine':
        assert not (args.warmup_epochs > 0)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs)
    else:
        # fallback: constant learning-rate scheduler
        def lambda_rule(ep):
            return 1.0

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lr_lambda=lambda_rule)

    # extract collate_fn
    if args.distributed:
        collate_fn = deepcopy(train_loader.collate_fn)
        model, optimizer, train_loader, scheduler = deepspeed.initialize(
            args=args,
            model=model,
            optimizer=optimizer,
            model_parameters=parameters,
            training_data=train_dataset,
            collate_fn=collate_fn,
            lr_scheduler=scheduler)

    # resume checkpoints
    start_epoch = 0
    if args.resume_checkpoint is None and Path(
            Path(save_dir) / 'checkpoint-latest.pt').exists():
        args.resume_checkpoint = os.path.join(
            save_dir, 'checkpoint-latest.pt')  # use the latest checkpoint
        print('Resumed from: ' + args.resume_checkpoint)
    if args.resume_checkpoint is not None:
        if args.distributed:
            if args.resume_optimizer:
                model.module, model.optimizer, model.lr_scheduler, start_epoch = resume(
                    args.resume_checkpoint,
                    model.module,
                    model.optimizer,
                    scheduler=model.lr_scheduler,
                    strict=(not args.resume_non_strict))
            else:
                model.module, _, _, start_epoch = resume(
                    args.resume_checkpoint,
                    model.module,
                    optimizer=None,
                    strict=(not args.resume_non_strict))
        else:
            if args.resume_optimizer:
                model, optimizer, scheduler, start_epoch = resume(
                    args.resume_checkpoint,
                    model,
                    optimizer,
                    scheduler=scheduler,
                    strict=(not args.resume_non_strict))
            else:
                model, _, _, start_epoch = resume(
                    args.resume_checkpoint,
                    model,
                    optimizer=None,
                    strict=(not args.resume_non_strict))

    # save dataset statistics
    if args.local_rank == 0:
        train_dataset.save_statistics(save_dir)
        val_dataset.save_statistics(save_dir)

    # main training loop
    avg_meters = {
        'kl_avg_meter': AverageValueMeter(),
        'l2_avg_meter': AverageValueMeter()
    }

    assert args.distributed

    epoch = start_epoch
    print("Start epoch: %d End epoch: %d" % (start_epoch, args.epochs))
    for epoch in range(start_epoch, args.epochs):
        if args.local_rank == 0:
            # evaluate on the validation set
            if epoch % args.val_freq == 0 and epoch != 0:
                model.eval()
                with torch.no_grad():
                    val_res = validate(model.module, args, val_loader, epoch,
                                       logger, save_dir)
                    for k, v in val_res.items():
                        v = v.cpu().detach().item()
                        send_slack(f'{k}:{v}, Epoch {epoch - 1}')
                        if logger is not None and v is not None:
                            logger.add_scalar(f'val_sample/{k}', v, epoch - 1)

        # train for one epoch
        train_one_epoch(epoch, model, criterion, optimizer, args, train_loader,
                        avg_meters, logger)

        # Only on HEAD process
        if args.local_rank == 0:
            # save checkpoints
            if (epoch + 1) % args.save_freq == 0:
                if args.eval:
                    validate_reconstruct_l2(epoch, val_loader, model,
                                            criterion, args, logger)
                save(model.module, model.optimizer, model.lr_scheduler,
                     epoch + 1,
                     Path(save_dir) / f'checkpoint-{epoch}.pt')
                save(model.module, model.optimizer, model.lr_scheduler,
                     epoch + 1,
                     Path(save_dir) / 'checkpoint-latest.pt')

            # save visualizations
            if (epoch + 1) % args.viz_freq == 0:
                with torch.no_grad():
                    visualize(model.module, args, val_loader, epoch, logger)

        # adjust the learning rate
        model.lr_scheduler.step()
        if logger is not None and args.local_rank == 0:
            logger.add_scalar('train lr',
                              model.lr_scheduler.get_last_lr()[0], epoch)

    model.eval()
    if args.local_rank == 0:
        with torch.no_grad():
            val_res = validate(model.module, args, val_loader, epoch, logger,
                               save_dir)
            for k, v in val_res.items():
                v = v.cpu().detach().item()
                send_slack(f'{k}:{v}, Epoch {epoch}')
                if logger is not None and v is not None:
                    logger.add_scalar(f'val_sample/{k}', v, epoch)

    if logger is not None:
        logger.flush()
        logger.close()
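
# Sketch (added, with hypothetical warmup_epochs=10 and epochs=100): the lr
# multiplier produced by the 'linear' scheduler branch above, i.e. a linear
# warm-up to 1.0 followed by a linear decay to 0 over the second half of training.
def linear_lambda_rule(ep, warmup_epochs=10, epochs=100):
    lr_w = min(1., ep / warmup_epochs) if warmup_epochs > 0 else 1.
    lr_l = 1.0 - max(0, ep - 0.5 * epochs) / float(0.5 * epochs)
    return lr_l * lr_w

for ep in (0, 5, 10, 50, 75, 100):
    print(ep, round(linear_lambda_rule(ep), 3))   # 0.0, 0.5, 1.0, 1.0, 0.5, 0.0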
Example No. 7
def main_worker(gpu, save_dir, args):
    # basic setup
    cudnn.benchmark = True
    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    model = HyperRegression(args)

    torch.cuda.set_device(args.gpu)
    model = model.cuda(args.gpu)
    start_epoch = 0
    optimizer = model.make_optimizer(args)
    if args.resume_checkpoint is None and os.path.exists(
            os.path.join(save_dir, 'checkpoint-latest.pt')):
        args.resume_checkpoint = os.path.join(
            save_dir, 'checkpoint-latest.pt')  # use the latest checkpoint
    if args.resume_checkpoint is not None:
        if args.resume_optimizer:
            model, optimizer, start_epoch = resume(
                args.resume_checkpoint,
                model,
                optimizer,
                strict=(not args.resume_non_strict))
        else:
            model, _, start_epoch = resume(args.resume_checkpoint,
                                           model,
                                           optimizer=None,
                                           strict=(not args.resume_non_strict))
        print('Resumed from: ' + args.resume_checkpoint)

    # main training loop
    start_time = time.time()
    point_nats_avg_meter = AverageValueMeter()
    if args.distributed:
        print("[Rank %d] World size : %d" % (args.rank, dist.get_world_size()))

    print("Start epoch: %d End epoch: %d" % (start_epoch, args.epochs))
    for epoch in range(start_epoch, args.epochs):
        print("Epoch starts:")
        data = ExampleData()
        train_loader = torch.utils.data.DataLoader(dataset=data,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=0,
                                                   pin_memory=True)
        for bidx, data in enumerate(train_loader):
            x, y = data
            x = x.float().to(args.gpu).unsqueeze(1)
            y = y.float().to(args.gpu).unsqueeze(1).unsqueeze(2)
            step = bidx + len(train_loader) * epoch
            model.train()
            recon_nats = model(x, y, optimizer, step, None)
            point_nats_avg_meter.update(recon_nats.item())
            if step % args.log_freq == 0:
                duration = time.time() - start_time
                start_time = time.time()
                print(
                    "[Rank %d] Epoch %d Batch [%2d/%2d] Time [%3.2fs] PointNats %2.5f"
                    % (args.rank, epoch, bidx, len(train_loader), duration,
                       point_nats_avg_meter.avg))
        # save visualizations
        kk = 3
        if (epoch + 1) % args.viz_freq == 0:
            # reconstructions
            model.eval()
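            # sample 100 y hypotheses for each of 100 evenly spaced x values in
            # [0, kk]; the repeat(100, axis=1) below pairs every x with its samples
            # so they can be scatter-plotted together.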
            x = torch.from_numpy(np.linspace(0, kk, num=100)).float().to(
                args.gpu).unsqueeze(1)
            _, y = model.decode(x, 100)
            x = x.cpu().detach().numpy()
            y = y.cpu().detach().numpy()
            x = np.expand_dims(x, 1).repeat(100, axis=1).flatten()
            y = y.flatten()
            figs, axs = plt.subplots(1, 1, figsize=(12, 12))
            plt.xlim([0, kk])
            plt.ylim([-2, 2])
            plt.scatter(x, y)
            plt.savefig(
                os.path.join(
                    save_dir, 'images',
                    'tr_vis_sampled_epoch%d-gpu%s.png' % (epoch, args.gpu)))
            plt.clf()
        if (epoch + 1) % args.save_freq == 0:
            save(model, optimizer, epoch + 1,
                 os.path.join(save_dir, 'checkpoint-%d.pt' % epoch))
            save(model, optimizer, epoch + 1,
                 os.path.join(save_dir, 'checkpoint-latest.pt'))
def val(compression_model,
        c_resnet_51,
        dataloader,
        class_loss,
        rate_loss=None,
        ps=None):
    if ps:
        ps.log('validating ... ')
    else:
        print('run val ... ')
    compression_model.eval()
    c_resnet_51.eval()

    class_loss_meter = AverageValueMeter()
    # if opt.use_imp:
    top5_acc_meter = AverageValueMeter()
    top1_acc_meter = AverageValueMeter()
    # total_loss_meter = AverageValueMeter()

    class_loss_meter.reset()
    # if opt.use_imp:
    top5_acc_meter.reset()
    top1_acc_meter.reset()
    # total_loss_meter.reset()

    for idx, (data, label) in enumerate(dataloader):
        #ps.log('%.0f%%' % (idx*100.0/len(dataloader)))
        val_input = Variable(data, volatile=True)
        label = Variable(label, volatile=True)

        if opt.use_gpu:
            val_input = val_input.cuda()
            label = label.cuda()

        # compressed_RGB = compression_model(val_input)

        compressed_feat = compression_model(val_input, need_decode=False)
        predicted = c_resnet_51(compressed_feat)

        val_class_loss = class_loss(predicted, label)
        class_loss_meter.add(val_class_loss.data[0])
        acc1, acc5 = accuracy(predicted.data, label.data, topk=(1, 5))

        top5_acc_meter.add(acc5[0])
        top1_acc_meter.add(acc1[0])

    return (class_loss_meter.value()[0], top5_acc_meter.value()[0],
            top1_acc_meter.value()[0])
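# A minimal sketch of the `accuracy` helper assumed by the val() above (the
# standard top-k accuracy routine from the PyTorch ImageNet example); the actual
# helper is not part of this snippet, so treat this as an illustration only.
def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)
    # indices of the top-k predictions, shape (maxk, batch_size)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res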
Example No. 9
network = network.cuda()  # move network to GPU
# If needed, load existing model
if opt.model != '':
    network.load_state_dict(torch.load(opt.model))
    print('Previous net weights loaded')
# ========================================================== #

# ===================CREATE optimizer and LOSSES================================= #
lrate = 0.001  # learning rate
optimizer = optim.Adam(network.parameters(), lr=lrate)
loss = torch.nn.L1Loss(reduction="mean")
# ========================================================== #

# =============DEFINE stuff for logs======================================== #
# meters to record stats on learning
train_loss = AverageValueMeter()
test_loss = AverageValueMeter()
best_train_loss = 10000.
with open(logfile, 'a') as f:  # open logfile and append network's architecture
    f.write(str(network) + '\n')
# ========================================================== #


# =============PROJECTION function======================================== #
def transformation(vertices, R, t):
    '''
    Calculate projective transformation of vertices given a projection matrix
    Input parameters:
    K: batch_size * 3 * 3 intrinsic camera matrix
    R, t: batch_size * 3 * 3, batch_size * 1 * 3 extrinsic calibration parameters
Example No. 10
def main_worker(gpu, save_dir, ngpus_per_node, args):
    # basic setup
    cudnn.benchmark = True
    normalize = False
    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    model = HyperRegression(args)

    torch.cuda.set_device(args.gpu)
    model = model.cuda(args.gpu)
    start_epoch = 0
    optimizer = model.make_optimizer(args)
    if args.resume_checkpoint is None and os.path.exists(
            os.path.join(save_dir, 'checkpoint-latest.pt')):
        args.resume_checkpoint = os.path.join(
            save_dir, 'checkpoint-latest.pt')  # use the latest checkpoint
    if args.resume_checkpoint is not None:
        if args.resume_optimizer:
            model, optimizer, start_epoch = resume(
                args.resume_checkpoint,
                model,
                optimizer,
                strict=(not args.resume_non_strict))
        else:
            model, _, start_epoch = resume(args.resume_checkpoint,
                                           model,
                                           optimizer=None,
                                           strict=(not args.resume_non_strict))
        print('Resumed from: ' + args.resume_checkpoint)

    # initialize datasets and loaders

    # initialize the learning rate scheduler
    if args.scheduler == 'exponential':
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, args.exp_decay)
    elif args.scheduler == 'step':
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=args.epochs // 2,
                                              gamma=0.1)
    elif args.scheduler == 'linear':

        def lambda_rule(ep):
            lr_l = 1.0 - max(0, ep - 0.5 * args.epochs) / float(
                0.5 * args.epochs)
            return lr_l

        scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                lr_lambda=lambda_rule)
    else:
        assert 0, "args.scheduler should be 'exponential', 'step', or 'linear'"
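    # Sanity check for the linear rule (illustrative numbers, not from the original
    # args): with args.epochs = 100 the multiplier stays at 1.0 through epoch 50,
    # then decays linearly, e.g. lambda_rule(75) == 0.5 and lambda_rule(100) == 0.0.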

    # main training loop
    start_time = time.time()
    entropy_avg_meter = AverageValueMeter()
    latent_nats_avg_meter = AverageValueMeter()
    point_nats_avg_meter = AverageValueMeter()
    if args.distributed:
        print("[Rank %d] World size : %d" % (args.rank, dist.get_world_size()))

    print("Start epoch: %d End epoch: %d" % (start_epoch, args.epochs))
    data = SDDData(split='train', normalize=normalize, root=args.data_dir)
    data_test = SDDData(split='test', normalize=normalize, root=args.data_dir)
    train_loader = torch.utils.data.DataLoader(dataset=data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=0,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(dataset=data_test,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=0,
                                              pin_memory=True)
    for epoch in range(start_epoch, args.epochs):
        # adjust the learning rate
        if (epoch + 1) % args.exp_decay_freq == 0:
            scheduler.step(epoch=epoch)

        # train for one epoch
        print("Epoch starts:")
        for bidx, data in enumerate(train_loader):
            # if bidx < 2:
            x, y = data
            #y = y.float().to(args.gpu).unsqueeze(1).repeat(1, 10).unsqueeze(2)
            x = x.float().to(args.gpu)
            y = y.float().to(args.gpu).unsqueeze(1)
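            # repeat each target 20 times along dim 1 and perturb every copy with
            # unit Gaussian noise (presumably to form multiple noisy hypotheses)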
            y = y.repeat(1, 20, 1)
            y += torch.randn(y.shape[0], y.shape[1], y.shape[2]).to(args.gpu)
            step = bidx + len(train_loader) * epoch
            model.train()
            recon_nats = model(x, y, optimizer, step, None)
            point_nats_avg_meter.update(recon_nats.item())
            if step % args.log_freq == 0:
                duration = time.time() - start_time
                start_time = time.time()
                print(
                    "[Rank %d] Epoch %d Batch [%2d/%2d] Time [%3.2fs] PointNats %2.5f"
                    % (args.rank, epoch, bidx, len(train_loader), duration,
                       point_nats_avg_meter.avg))
                # print("Memory")
                # print(process.memory_info().rss / (1024.0 ** 3))
        # save visualizations
        if (epoch + 1) % args.viz_freq == 0:
            # reconstructions
            model.eval()
            for bidx, data in enumerate(test_loader):
                x, _ = data
                x = x.float().to(args.gpu)
                _, y_pred = model.decode(x, 100)
                y_pred = y_pred.cpu().detach().numpy().squeeze()
                # y_pred[y_pred < 0] = 0
                # y_pred[y_pred >= 0.98] = 0.98
                testing_sequence = data_test.dataset.scenes[
                    data_test.test_id].sequences[bidx]
                objects_list = []
                for k in range(3):
                    objects_list.append(
                        decode_obj(testing_sequence.objects[k],
                                   testing_sequence.id))
                objects = np.stack(objects_list, axis=0)
                gt_object = decode_obj(testing_sequence.objects[-1],
                                       testing_sequence.id)
                drawn_img_hyps = draw_hyps(testing_sequence.imgs[-1], y_pred,
                                           gt_object, objects, normalize)
                cv2.imwrite(
                    os.path.join(save_dir, 'images',
                                 str(bidx) + '-' + str(epoch) + '-hyps.jpg'),
                    drawn_img_hyps)
        if (epoch + 1) % args.save_freq == 0:
            save(model, optimizer, epoch + 1,
                 os.path.join(save_dir, 'checkpoint-%d.pt' % epoch))
            save(model, optimizer, epoch + 1,
                 os.path.join(save_dir, 'checkpoint-latest.pt'))
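

# A minimal sketch of the `save` / `resume` helpers that the main_worker snippets
# above rely on, based only on how they are called (save(model, optimizer, epoch,
# path) and resume(path, model, optimizer=None, strict=True)); this is an
# assumption for illustration, not the repository's actual implementation.
import torch


def save_sketch(model, optimizer, epoch, path):
    torch.save({
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, path)


def resume_sketch(path, model, optimizer=None, strict=True):
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['model'], strict=strict)
    if optimizer is not None and 'optimizer' in ckpt:
        optimizer.load_state_dict(ckpt['optimizer'])
    return model, optimizer, ckpt['epoch']
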
def train(**kwargs):
    global batch_model_id
    opt.parse(kwargs)
    if opt.use_batch_process:
        max_bp_times = len(opt.exp_ids)
        if batch_model_id >= max_bp_times:
            print('Batch Processing Done!')
            return
        else:
            print('Cur Batch Processing ID is %d/%d.' %
                  (batch_model_id + 1, max_bp_times))
            opt.r = opt.r_s[batch_model_id]
            opt.exp_id = opt.exp_ids[batch_model_id]
            opt.exp_desc = opt.exp_desc_LUT[opt.exp_id]
            opt.plot_path = "plot/plot_%d" % opt.exp_id
            print('Cur Model(exp_%d) r = %f, desc = %s. ' %
                  (opt.exp_id, opt.r, opt.exp_desc))
    opt.make_new_dirs()
    # log file
    EvalVal = opt.only_init_val and opt.init_val and not opt.test_test
    EvalTest = opt.only_init_val and opt.init_val and opt.test_test
    EvalSuffix = ""
    if EvalVal:
        EvalSuffix = "_val"
    if EvalTest:
        EvalSuffix = "_test"
    logfile_name = opt.exp_desc + time.strftime(
        "_%m_%d_%H:%M:%S") + EvalSuffix + ".log.txt"

    ps = PlotSaver(logfile_name, log_to_stdout=opt.log_to_stdout)

    # step1: Model
    # model = getattr(models, opt.model)(use_imp = opt.use_imp, n = opt.feat_num, model_name=opt.exp_desc + ("_r={r}_gm={w}".format(
    #                                                             r=opt.rate_loss_threshold,
    #                                                             w=opt.rate_loss_weight)
    #                                                           if opt.use_imp else "_no_imp"))

    model = getattr(models, opt.model)(use_imp=opt.use_imp,
                                       n=opt.feat_num,
                                       model_name=opt.exp_desc)
    # print (model)
    # pdb.set_trace()
    global use_data_parallel
    if opt.use_gpu:
        model, use_data_parallel = multiple_gpu_process(model)

    # whether we actually end up on the GPU or fall back to CPU
    opt.use_gpu = opt.use_gpu and use_data_parallel >= 0
    cudnn.benchmark = True

    # step2: Data
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    train_data_transforms = transforms.Compose([
        # transforms.RandomHorizontalFlip(),  TODO: reimplement this myself so it operates on the label and data simultaneously
        transforms.RandomCrop(128),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    val_data_transforms = transforms.Compose(
        [transforms.CenterCrop(128),
         transforms.ToTensor(), normalize])

    caffe_data_transforms = transforms.Compose([transforms.CenterCrop(128)])
    # transforms ||  data
    train_data = ImageFilelist(
        flist=opt.train_data_list,
        transform=train_data_transforms,
    )

    val_data = ImageFilelist(
        flist=opt.val_data_list,
        prefix=opt.val_data_prefix,
        transform=val_data_transforms,
    )

    val_data_caffe = ImageFilelist(flist=opt.val_data_list,
                                   transform=caffe_data_transforms)

    test_data_caffe = ImageFilelist(flist=opt.test_data_list,
                                    transform=caffe_data_transforms)

    if opt.make_caffe_data:
        save_caffe_data(test_data_caffe)
        print('Make caffe dataset over!')
        return

    # train_data = ImageCropWithBBoxMaskDataset(
    #     opt.train_data_list,
    #     train_data_transforms,
    #     contrastive_degree = opt.contrastive_degree,
    #     mse_bbox_weight = opt.input_original_bbox_weight
    # )
    # val_data = ImageCropWithBBoxMaskDataset(
    #     opt.val_data_list,
    #     val_data_transforms,
    #     contrastive_degree = opt.contrastive_degree,
    #     mse_bbox_weight = opt.input_original_bbox_weight
    # )
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers,
                                  pin_memory=True)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers,
                                pin_memory=True)

    # step3: criterion and optimizer

    mse_loss = t.nn.MSELoss(size_average=False)

    if opt.use_imp:
        # TODO: new rate loss
        rate_loss = RateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)
        # rate_loss = LimuRateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)

    def weighted_mse_loss(input, target, weight):
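        # Active implementation: build a weight map that is 1 everywhere except
        # where `weight` marks the original bbox interior (opt.input_original_bbox_inner),
        # which is re-weighted to opt.mse_bbox_weight, then return the weighted SSE.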
        # weight[weight!=opt.mse_bbox_weight] = 1
        # weight[weight==opt.mse_bbox_weight] = opt.mse_bbox_weight
        # print('max val', weight.max())
        # return mse_loss(input, target)
        # weight_clone = weight.clone()
        # weight_clone[weight_clone == opt.input_original_bbox_weight] = 0
        # return t.sum(weight_clone * (input - target) ** 2)
        weight_clone = t.ones_like(weight)
        weight_clone[weight ==
                     opt.input_original_bbox_inner] = opt.mse_bbox_weight
        return t.sum(weight_clone * (input - target)**2)

    def yolo_rate_loss(imp_map, mask_r):
        return rate_loss(imp_map)
        # V2 contrastive_degree must be 0!
        # return YoloRateLossV2(mask_r, opt.rate_loss_threshold, opt.rate_loss_weight)(imp_map)

    start_epoch = 0
    decay_file_create_time = -1  # track the decay file's modification time so the same file does not trigger repeated lr decay

    previous_loss = 1e100
    tolerant_now = 0
    same_lr_epoch = 0
    lr = opt.lr
    optimizer = t.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

    if opt.resume:
        start_epoch = (model.module if use_data_parallel == 1 else model).load(
            None if opt.finetune else optimizer, opt.resume, opt.finetune)

        if opt.finetune:
            print('Finetune from model checkpoint file', opt.resume)
        else:
            print('Resume training from checkpoint file', opt.resume)
            print('Continue training from epoch %d.' % start_epoch)
            same_lr_epoch = start_epoch % opt.lr_anneal_epochs
            decay_times = start_epoch // opt.lr_anneal_epochs
            lr = opt.lr * (opt.lr_decay**decay_times)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print('Decay lr %d times, now lr is %e.' % (decay_times, lr))

    # step4: meters
    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    # ps init

    ps.new_plot('train mse loss',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_mse_loss")
    ps.new_plot('val mse loss', 1, xlabel="epoch", ylabel="val_mse_loss")
    if opt.use_imp:
        ps.new_plot('train rate value',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_value")
        ps.new_plot('train rate loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_loss")
        ps.new_plot('train total loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_total_loss")
        ps.new_plot('val rate value',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_value")
        ps.new_plot('val rate loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_loss")
        ps.new_plot('val total loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_total_loss")

    # at eval time the resumed epoch may already equal max_epoch (e.g. both 600), so extend max_epoch to let the loop run
    if opt.only_init_val and opt.max_epoch <= start_epoch:
        opt.max_epoch = start_epoch + 2

    for epoch in range(start_epoch + 1, opt.max_epoch + 1):

        same_lr_epoch += 1
        # per epoch avg loss meter
        mse_loss_meter.reset()
        if opt.use_imp:
            rate_display_meter.reset()
            rate_loss_meter.reset()
            total_loss_meter.reset()
        else:
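            # without importance maps the total loss is just the MSE term,
            # so reuse the same meter for both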
            total_loss_meter = mse_loss_meter
        # cur_epoch_loss refresh every epoch
        # vis.refresh_plot('cur epoch train mse loss')
        ps.new_plot("cur epoch train mse loss",
                    opt.print_freq,
                    xlabel="iteration in cur epoch",
                    ylabel="train_mse_loss")
        # progress_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader), ascii=True)
        # progress_bar.set_description('epoch %d/%d, loss = 0.00' % (epoch, opt.max_epoch))

        # Init val
        if (epoch == start_epoch + 1) and opt.init_val:
            print('Init validation ... ')
            if opt.use_imp:
                mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                    model, val_dataloader, mse_loss, rate_loss, ps)
            else:
                mse_val_loss = val(model, val_dataloader, mse_loss, None, ps)

            ps.add_point('val mse loss', mse_val_loss)
            if opt.use_imp:
                ps.add_point('val rate value', rate_val_display)
                ps.add_point('val rate loss', rate_val_loss)
                ps.add_point('val total loss', total_val_loss)

            ps.make_plot('val mse loss')

            if opt.use_imp:
                ps.make_plot('val rate value')
                ps.make_plot('val rate loss')
                ps.make_plot('val total loss')

            # log sth.
            if opt.use_imp:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                    .format(epoch=epoch,
                            lr=lr,
                            val_mse_loss=mse_val_loss,
                            val_rate_loss=rate_val_loss,
                            val_total_loss=total_val_loss,
                            val_rate_display=rate_val_display))
            else:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss: {val_mse_loss}'
                    .format(epoch=epoch, lr=lr, val_mse_loss=mse_val_loss))

        if opt.only_init_val:
            print('Only Init Val Over!')
            return

        model.train()

        # if epoch == start_epoch + 1:
        #     print ('Start training, please inspect log file %s!' % logfile_name)
        # mask is the detection bounding box mask
        for idx, data in enumerate(train_dataloader):

            # pdb.set_trace()

            data = Variable(data)
            # mask = Variable(mask)
            # o_mask = Variable(o_mask, requires_grad=False)

            if opt.use_gpu:
                data = data.cuda(async=True)
                # mask = mask.cuda(async = True)
                # o_mask = o_mask.cuda(async = True)

            # pdb.set_trace()

            optimizer.zero_grad()
            if opt.use_imp:
                reconstructed, imp_mask_sigmoid = model(data)
            else:
                reconstructed = model(data)

            # print ('imp_mask_height', model.imp_mask_height)
            # pdb.set_trace()

            # print ('type recons', type(reconstructed.data))

            loss = mse_loss(reconstructed, data)
            # loss = mse_loss(reconstructed, data)
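            # Caffe-style Euclidean loss: sum of squared errors divided by 2N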
            caffe_loss = loss / (2 * opt.batch_size)

            if opt.use_imp:
                rate_loss_ = rate_loss(imp_mask_sigmoid)
                # rate_loss_ = yolo_rate_loss(imp_mask_sigmoid, mask)
                total_loss = caffe_loss + rate_loss_
            else:
                total_loss = caffe_loss

            total_loss.backward()
            optimizer.step()

            mse_loss_meter.add(caffe_loss.data[0])

            if opt.use_imp:
                rate_loss_meter.add(rate_loss_.data[0])
                rate_display_meter.add(imp_mask_sigmoid.data.mean())
                total_loss_meter.add(total_loss.data[0])

            if idx % opt.print_freq == opt.print_freq - 1:
                ps.add_point(
                    'train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                ps.add_point(
                    'cur epoch train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                if opt.use_imp:
                    ps.add_point(
                        'train rate value',
                        rate_display_meter.value()[0]
                        if opt.print_smooth else imp_mask_sigmoid.data.mean())
                    ps.add_point(
                        'train rate loss',
                        rate_loss_meter.value()[0]
                        if opt.print_smooth else rate_loss_.data[0])
                    ps.add_point(
                        'train total loss',
                        total_loss_meter.value()[0]
                        if opt.print_smooth else total_loss.data[0])

                if not opt.use_imp:
                    ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' %
                           (epoch, opt.max_epoch, idx, len(train_dataloader),
                            total_loss_meter.value()[0], lr))
                else:
                    ps.log(
                        'Epoch %d/%d, Iter %d/%d, loss = %.2f, mse_loss = %.2f, rate_loss = %.2f, rate_display = %.2f, lr = %.8f'
                        %
                        (epoch, opt.max_epoch, idx, len(train_dataloader),
                         total_loss_meter.value()[0],
                         mse_loss_meter.value()[0], rate_loss_meter.value()[0],
                         rate_display_meter.value()[0], lr))

                # drop into the debugger when the debug file exists
                if os.path.exists(opt.debug_file):
                    pdb.set_trace()

        if epoch % opt.save_interval == 0:
            print('Save checkpoint file of epoch %d.' % epoch)
            if use_data_parallel == 1:
                model.module.save(optimizer, epoch)
            else:
                model.save(optimizer, epoch)

        ps.make_plot('train mse loss')
        ps.make_plot('cur epoch train mse loss', epoch)
        if opt.use_imp:
            ps.make_plot("train rate value")
            ps.make_plot("train rate loss")
            ps.make_plot("train total loss")

        if epoch % opt.eval_interval == 0:
            print('Validating ...')
            # val
            if opt.use_imp:
                mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                    model, val_dataloader, mse_loss, rate_loss, ps)
            else:
                mse_val_loss = val(model, val_dataloader, mse_loss, None, ps)

            ps.add_point('val mse loss', mse_val_loss)
            if opt.use_imp:
                ps.add_point('val rate value', rate_val_display)
                ps.add_point('val rate loss', rate_val_loss)
                ps.add_point('val total loss', total_val_loss)

            ps.make_plot('val mse loss')

            if opt.use_imp:
                ps.make_plot('val rate value')
                ps.make_plot('val rate loss')
                ps.make_plot('val total loss')

            # log sth.
            if opt.use_imp:
                ps.log(
                    'Epoch:{epoch}, lr:{lr}, train_mse_loss: {train_mse_loss}, train_rate_loss: {train_rate_loss}, train_total_loss: {train_total_loss}, train_rate_display: {train_rate_display} \n\
    val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                    .format(epoch=epoch,
                            lr=lr,
                            train_mse_loss=mse_loss_meter.value()[0],
                            train_rate_loss=rate_loss_meter.value()[0],
                            train_total_loss=total_loss_meter.value()[0],
                            train_rate_display=rate_display_meter.value()[0],
                            val_mse_loss=mse_val_loss,
                            val_rate_loss=rate_val_loss,
                            val_total_loss=total_val_loss,
                            val_rate_display=rate_val_display))
            else:
                ps.log(
                    'Epoch:{epoch}, lr:{lr}, train_mse_loss:{train_mse_loss}, val_mse_loss:{val_mse_loss}'
                    .format(epoch=epoch,
                            lr=lr,
                            train_mse_loss=mse_loss_meter.value()[0],
                            val_mse_loss=mse_val_loss))

        # Adaptively adjust the learning rate:
        # for each lr, if the loss exceeds the previous epoch's value opt.tolerant_max times,
        # decay the learning rate
        # if loss_meter.value()[0] > previous_loss:
        if opt.use_early_adjust:
            if total_loss_meter.value()[0] > previous_loss:
                tolerant_now += 1
                if tolerant_now == opt.tolerant_max:
                    tolerant_now = 0
                    same_lr_epoch = 0
                    lr = lr * opt.lr_decay
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Due to early stop anneal lr to %.10f at epoch %d' %
                          (lr, epoch))
                    ps.log('Due to early stop anneal lr to %.10f at epoch %d' %
                           (lr, epoch))

            else:
                # tolerant_now -= 1
                tolerant_now = 0

        if epoch % opt.lr_anneal_epochs == 0:
            # if same_lr_epoch and same_lr_epoch % opt.lr_anneal_epochs == 0:
            same_lr_epoch = 0
            tolerant_now = 0
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print('Anneal lr to %.10f at epoch %d due to full epochs.' %
                  (lr, epoch))
            ps.log('Anneal lr to %.10f at epoch %d due to full epochs.' %
                   (lr, epoch))

        if opt.use_file_decay_lr and os.path.exists(opt.lr_decay_file):
            cur_mtime = os.path.getmtime(opt.lr_decay_file)
            if cur_mtime > decay_file_create_time:
                decay_file_create_time = cur_mtime
                lr = lr * opt.lr_decay
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
                print(
                    'Anneal lr to %.10f at epoch %d due to decay-file indicator.'
                    % (lr, epoch))
                ps.log(
                    'Anneal lr to %.10f at epoch %d due to decay-file indicator.'
                    % (lr, epoch))

        previous_loss = total_loss_meter.value()[0]
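    # batch processing: advance to the next experiment id and re-enter train()
    # with the same kwargs once this model has finished training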
    if opt.use_batch_process:
        batch_model_id += 1
        train(**kwargs)
Example No. 12
network = network.cuda()  # move network to GPU
# If needed, load existing model
if opt.model != '':
  network.load_state_dict(torch.load(opt.model))
  print('Previous net weights loaded')
# ========================================================== #

# ===================CREATE optimizer================================= #
lrate = 0.001  # learning rate
optimizer = optim.Adam(network.parameters(), lr=lrate)
# ========================================================== #

# =============DEFINE stuff for logs======================================== #
# meters to record stats on learning
total_train_loss = AverageValueMeter()
chd_train_loss = AverageValueMeter()
occ_train_loss = AverageValueMeter()
test_loss = AverageValueMeter()
best_train_loss = 10
with open(logfile, 'a') as f:  # open logfile and append network's architecture
  f.write(str(network) + '\n')
# ========================================================== #

# =============FIRST TRAINING EPOCH: occupancy only======================================== #
network.train()
for i, data in enumerate(dataloader_small_bs, 0):
  optimizer.zero_grad()

  points, img, depth_maps, _, camRt, _, _, _ = data
Example No. 13
def imgrad_yx(img):
  N,C,_,_ = img.size()
  grad_y, grad_x = imgrad(img)
  return torch.cat((grad_y.view(N,C,-1), grad_x.view(N,C,-1)), dim=1)

rmse = RMSE()
depth_criterion = RMSE_log()
grad_criterion = GradLoss()
normal_criterion = NormalLoss()
eval_metric = RMSE_log()

# ========================================================== #

# =============DEFINE stuff for logs======================================== #
# meters to record stats on learning
train_total = AverageValueMeter()
train_logRMSE = AverageValueMeter()
train_grad = AverageValueMeter()
train_normal = AverageValueMeter()
test_logRMSE = AverageValueMeter()
test_RMSE = AverageValueMeter()
best_train_loss = 10000.
with open(logfile, 'a') as f:  # open logfile and append network's architecture
  f.write(str(network) + '\n')
# ========================================================== #

# ===================TRAINING LOOP================================= #
# constants for loss balancing
grad_factor = 10.
normal_factor = 1.
for epoch in range(opt.nepoch):
Example No. 14
def train(**kwargs):
    opt.parse(kwargs)
    # log file
    ps = PlotSaver("Context_Baseline_p32_no_imp_plain_1_" +
                   time.strftime("%m_%d_%H:%M:%S") + ".log.txt")

    # step1: Model
    model = getattr(models, opt.model)(
        use_imp=opt.use_imp,
        model_name="Context_Baseline_p32_imp_r={r}_gama={w}_plain".format(
            r=opt.rate_loss_threshold, w=opt.rate_loss_weight)
        if opt.use_imp else "ContextBaseNoImpP32_Plain")

    if opt.use_gpu:
        model = multiple_gpu_process(model)

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    data_transforms = transforms.Compose([transforms.ToTensor(), normalize])
    train_data = datasets.ImageFolder(opt.train_data_root, data_transforms)
    val_data = datasets.ImageFolder(opt.val_data_root, data_transforms)
    # opt.batch_size  --> 1
    train_dataloader = DataLoader(train_data,
                                  1,
                                  shuffle=True,
                                  num_workers=opt.num_workers,
                                  pin_memory=True)
    val_dataloader = DataLoader(val_data,
                                1,
                                shuffle=False,
                                num_workers=opt.num_workers,
                                pin_memory=True)

    # step3: criterion and optimizer

    mse_loss = t.nn.MSELoss(size_average=False)

    if opt.use_imp:
        rate_loss = LimuRateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)

    lr = opt.lr

    optimizer = t.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

    start_epoch = 0

    if opt.resume:
        if use_data_parallel:
            start_epoch = model.module.load(
                None if opt.finetune else optimizer, opt.resume, opt.finetune)
        else:
            start_epoch = model.load(None if opt.finetune else optimizer,
                                     opt.resume, opt.finetune)

        if opt.finetune:
            print('Finetune from model checkpoint file', opt.resume)
        else:
            print('Resume training from checkpoint file', opt.resume)
            print('Continue training at epoch %d.' % start_epoch)

    # step4: meters
    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    previous_loss = 1e100
    tolerant_now = 0
    same_lr_epoch = 0

    # ps init

    ps.new_plot('train mse loss',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_mse_loss")
    ps.new_plot('val mse loss', 1, xlabel="epoch", ylabel="val_mse_loss")
    if opt.use_imp:
        ps.new_plot('train rate value',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_value")
        ps.new_plot('train rate loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_loss")
        ps.new_plot('train total loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_total_loss")
        ps.new_plot('val rate value',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_value")
        ps.new_plot('val rate loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_loss")
        ps.new_plot('val total loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_total_loss")

    for epoch in range(start_epoch + 1, opt.max_epoch + 1):

        same_lr_epoch += 1
        # per epoch avg loss meter
        mse_loss_meter.reset()
        if opt.use_imp:
            rate_display_meter.reset()
            rate_loss_meter.reset()
            total_loss_meter.reset()
        else:
            total_loss_meter = mse_loss_meter
        # cur_epoch_loss refresh every epoch

        ps.new_plot("cur epoch train mse loss",
                    opt.print_freq,
                    xlabel="iteration in cur epoch",
                    ylabel="train_mse_loss")

        # Init val
        if (epoch == start_epoch + 1) and opt.init_val:
            print('Init validation ... ')
            if opt.use_imp:
                mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                    model, val_dataloader, mse_loss, rate_loss, ps)
            else:
                mse_val_loss = val(model, val_dataloader, mse_loss, None, ps)

            ps.add_point('val mse loss', mse_val_loss)
            if opt.use_imp:
                ps.add_point('val rate value', rate_val_display)
                ps.add_point('val rate loss', rate_val_loss)
                ps.add_point('val total loss', total_val_loss)

            # make plot
            ps.make_plot('val mse loss')

            if opt.use_imp:
                ps.make_plot('val rate value')
                ps.make_plot('val rate loss')
                ps.make_plot('val total loss')

            # log sth.
            if opt.use_imp:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                    .format(epoch=epoch,
                            lr=lr,
                            val_mse_loss=mse_val_loss,
                            val_rate_loss=rate_val_loss,
                            val_total_loss=total_val_loss,
                            val_rate_display=rate_val_display))
            else:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss:{val_mse_loss}'
                    .format(
                        epoch=epoch,
                        lr=lr,
                        # train_mse_loss = mse_loss_meter.value()[0],
                        val_mse_loss=mse_val_loss))

        model.train()

        def to_patches(x, patch_size):
            print(type(x))
            _, h, w = x.size()
            print('original h,w', h, w)
            dw = (patch_size - w % patch_size) if w % patch_size else 0
            dh = (patch_size - h % patch_size) if h % patch_size else 0
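            # pad symmetrically so h and w become exact multiples of patch_size;
            # e.g. with patch_size = 32 (illustrative) and w = 100, w % 32 == 4,
            # so dw == 28 and the padded width is 128 (exactly 4 patches)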

            x = F.pad(x, (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2))

            _, h, w = x.size()
            print(h, w)
            pdb.set_trace()

            tv.utils.save_image(x.data, 'test_images/padded_original_img.png')

            num_patch_x = w // patch_size
            num_patch_y = h // patch_size
            print(num_patch_x, num_patch_y)
            patches = []
            for i in range(num_patch_y):
                for j in range(num_patch_x):
                    patch = x[:, i * patch_size:(i + 1) * patch_size,
                              j * patch_size:(j + 1) * patch_size]
                    patches.append(patch.contiguous())
                if (j + 1) * patch_size < w:
                    extra_patch = x[:, i * patch_size:(i + 1) * patch_size,
                                    (j + 1) * patch_size:w]
                    patches.append(extra_patch.contiguous())

            return patches

        # _ is the corresponding label; the compression model does not use it.
        for idx, (data, _) in enumerate(train_dataloader):

            if idx == 0:
                print('skip idx =', idx)
                continue
            # ipt = Variable(data[0])
            ipt = data[0]
            # if opt.use_gpu:
            #     # if not use_data_parallel:  # because ipt is also target, so we still need to transfer it to CUDA
            #         ipt = ipt.cuda(async = True)

            # ipt is a full image, so I need to split it into crops
            # so set batch_size = 1 for simplicity at first

            # pdb.set_trace()

            tv.utils.save_image(ipt, "test_imgs/original.png")
            patches = to_patches(ipt, opt.patch_size)

            for (i, p) in enumerate(patches):
                tv.utils.save_image(p, "test_imgs/%s.png" % i)

            pdb.set_trace()

            optimizer.zero_grad()
            if opt.use_imp:
                reconstructed, imp_mask_sigmoid = model(ipt)
            else:
                reconstructed = model(ipt)

            # print ('imp_mask_height', model.imp_mask_height)
            # pdb.set_trace()

            # print ('type recons', type(reconstructed.data))
            loss = mse_loss(reconstructed, ipt)
            caffe_loss = loss / (2 * opt.batch_size)

            if opt.use_imp:
                rate_loss_display = imp_mask_sigmoid
                rate_loss_ = rate_loss(rate_loss_display)

                total_loss = caffe_loss + rate_loss_
            else:
                total_loss = caffe_loss

            # 1.
            total_loss.backward()
            # caffe_loss.backward()
            optimizer.step()

            mse_loss_meter.add(caffe_loss.data[0])

            if opt.use_imp:
                rate_loss_meter.add(rate_loss_.data[0])
                rate_display_meter.add(rate_loss_display.data.mean())
                total_loss_meter.add(total_loss.data[0])

            if idx % opt.print_freq == opt.print_freq - 1:
                ps.add_point(
                    'train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                ps.add_point(
                    'cur epoch train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                # print (rate_loss_display.data.mean())
                if opt.use_imp:
                    ps.add_point(
                        'train rate value',
                        rate_display_meter.value()[0]
                        if opt.print_smooth else rate_loss_display.data.mean())
                    ps.add_point(
                        'train rate loss',
                        rate_loss_meter.value()[0]
                        if opt.print_smooth else rate_loss_.data[0])
                    ps.add_point(
                        'train total loss',
                        total_loss_meter.value()[0]
                        if opt.print_smooth else total_loss.data[0])
                # pdb.set_trace()
                # progress_bar.set_description('epoch %d/%d, loss = %.2f' % (epoch, opt.max_epoch, total_loss.data[0]))

                #  2.
                # ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' % (epoch, opt.max_epoch, idx, len(train_dataloader), total_loss_meter.value()[0], lr))
                # ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' % (epoch, opt.max_epoch, idx, len(train_dataloader), mse_loss_meter.value()[0], lr))

                # ps.log('loss = %f' % caffe_loss.data[0])
                # print(total_loss.data[0])
                # input('waiting......')

                if not opt.use_imp:
                    ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' %
                           (epoch, opt.max_epoch, idx, len(train_dataloader),
                            total_loss_meter.value()[0], lr))
                else:
                    ps.log(
                        'Epoch %d/%d, Iter %d/%d, loss = %.2f, mse_loss = %.2f, rate_loss = %.2f, rate_display = %.2f, lr = %.8f'
                        %
                        (epoch, opt.max_epoch, idx, len(train_dataloader),
                         total_loss_meter.value()[0],
                         mse_loss_meter.value()[0], rate_loss_meter.value()[0],
                         rate_display_meter.value()[0], lr))
                # drop into the debugger when the debug file exists
                if os.path.exists(opt.debug_file):
                    # import pdb
                    pdb.set_trace()

        if use_data_parallel:
            # print (type(model.module))
            # print (model)
            # print (type(model))
            model.module.save(optimizer, epoch)
        else:
            model.save(optimizer, epoch)

        # print ('case error', total_loss.data[0])
        # print ('smoothed error', total_loss_meter.value()[0])

        # plot before val can ease me
        ps.make_plot('train mse loss')  # all epochs share one image, so no epoch id is passed
        ps.make_plot('cur epoch train mse loss', epoch)
        if opt.use_imp:
            ps.make_plot("train rate value")
            ps.make_plot("train rate loss")
            ps.make_plot("train total loss")

        # val
        if opt.use_imp:
            mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                model, val_dataloader, mse_loss, rate_loss, ps)
        else:
            mse_val_loss = val(model, val_dataloader, mse_loss, None, ps)

        ps.add_point('val mse loss', mse_val_loss)
        if opt.use_imp:
            ps.add_point('val rate value', rate_val_display)
            ps.add_point('val rate loss', rate_val_loss)
            ps.add_point('val total loss', total_val_loss)

        # make plot
        # ps.make_plot('train mse loss', "")   # all epoch share a same img, so give "" to epoch
        # ps.make_plot('cur epoch train mse loss',epoch)
        ps.make_plot('val mse loss')

        if opt.use_imp:
            # ps.make_plot("train rate value","")
            # ps.make_plot("train rate loss","")
            # ps.make_plot("train total loss","")
            ps.make_plot('val rate value')
            ps.make_plot('val rate loss')
            ps.make_plot('val total loss')

        # log sth.
        if opt.use_imp:
            ps.log(
                'Epoch:{epoch}, lr:{lr}, train_mse_loss: {train_mse_loss}, train_rate_loss: {train_rate_loss}, train_total_loss: {train_total_loss}, train_rate_display: {train_rate_display} \n\
val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                .format(epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        train_rate_loss=rate_loss_meter.value()[0],
                        train_total_loss=total_loss_meter.value()[0],
                        train_rate_display=rate_display_meter.value()[0],
                        val_mse_loss=mse_val_loss,
                        val_rate_loss=rate_val_loss,
                        val_total_loss=total_val_loss,
                        val_rate_display=rate_val_display))
        else:
            ps.log(
                'Epoch:{epoch}, lr:{lr}, train_mse_loss:{train_mse_loss}, val_mse_loss:{val_mse_loss}'
                .format(epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        val_mse_loss=mse_val_loss))

        # Adaptively adjust the learning rate:
        # for each lr, if the loss exceeds the previous epoch's value opt.tolerant_max times,
        # decay the learning rate
        # if loss_meter.value()[0] > previous_loss:
        if opt.use_early_adjust:
            if total_loss_meter.value()[0] > previous_loss:
                tolerant_now += 1
                if tolerant_now == opt.tolerant_max:
                    tolerant_now = 0
                    same_lr_epoch = 0
                    lr = lr * opt.lr_decay
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Due to early stop anneal lr to', lr, 'at epoch',
                          epoch)
                    ps.log('Due to early stop anneal lr to %.10f at epoch %d' %
                           (lr, epoch))

            else:
                tolerant_now -= 1

        # if same_lr_epoch and same_lr_epoch % opt.lr_anneal_epochs == 0:
        #     same_lr_epoch = 0
        #     tolerant_now = 0
        #     lr = lr * opt.lr_decay
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = lr
        #     print ('Due to full epochs anneal lr to',lr,'at epoch',epoch)
        #     ps.log ('Due to full epochs anneal lr to %.10f at epoch %d' % (lr, epoch))

        if opt.use_file_decay_lr and os.path.exists(opt.lr_decay_file):
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        # previous_loss = total_loss_meter.value()[0] if opt.use_imp else mse_loss_meter.value()[0]
        previous_loss = total_loss_meter.value()[0]
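

# A minimal sketch of what a rate loss with (threshold, weight) arguments could
# look like, inferred only from how RateLoss / LimuRateLoss are called above
# (rate_loss(imp_map) -> scalar); this is an assumption for illustration, not the
# repository's actual implementation.
import torch as t
from torch import nn


class RateLossSketch(nn.Module):
    def __init__(self, threshold, weight):
        super(RateLossSketch, self).__init__()
        self.threshold = threshold
        self.weight = weight

    def forward(self, imp_map):
        # penalize the mean importance (a proxy for the bit rate) only when it
        # exceeds the target threshold
        excess = imp_map.mean() - self.threshold
        return self.weight * t.clamp(excess, min=0.0)
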
def train(**kwargs):
    opt.parse(kwargs)
    # vis = Visualizer(opt.env)
    # log file
    ps = PlotSaver("CLIC_finetune_With_TestSet_0.13_" +
                   time.strftime("%m_%d_%H:%M:%S") + ".log.txt")

    # step1: Model

    # CWCNN_limu_ImageNet_imp_
    model = getattr(models, opt.model)(
        use_imp=opt.use_imp,
        model_name="CLIC_imp_ft_testSet_2_r={r}_gama={w}".format(
            r=opt.rate_loss_threshold, w=opt.rate_loss_weight)
        if opt.use_imp else None)
    # if opt.use_imp else "test_pytorch")
    if opt.use_gpu:
        model = multiple_gpu_process(model)
        # model.cuda()  # because I want model.imp_mask_sigmoid

    # freeze the decoder network and finetune the encoder + importance map with the train + val sets
    if opt.freeze_decoder:
        for param in model.module.decoder.parameters():
            # print (param.requires_grad)
            param.requires_grad = False

    cudnn.benchmark = True

    # step2: Data
    # normalize = transforms.Normalize(
    #                 mean=[0.485, 0.456, 0.406],
    #                 std=[0.229, 0.224, 0.225]
    #                 )

    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    # train_data_transforms = transforms.Compose(
    #     [
    #         transforms.Resize(256),
    #         # transforms.RandomCrop(128),
    #         # transforms.Resize((128,128)),
    #         # transforms.RandomHorizontalFlip(),
    #         transforms.ToTensor(),
    #         # transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
    #         normalize
    #     ]
    # )
    # val_data_transforms = transforms.Compose(
    #     [
    #         # transforms.Resize(256),
    #         # transforms.CenterCrop(128),
    #         # transforms.Resize((128,128)),
    #         # transforms.TenCrop(224),
    #         # transforms.Lambda(lambda crops: t.stack(([normalize(transforms.ToTensor()(crop)) for crop in crops]))),
    #         transforms.ToTensor(),
    #         normalize
    #     ]
    # )

    train_data_transforms = transforms.Compose([
        transforms.Resize(256),
        #transforms.Scale(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    val_data_transforms = transforms.Compose([
        transforms.Resize(256),
        #transforms.Scale(256),
        transforms.CenterCrop(224),
        # transforms.TenCrop(224),
        # transforms.Lambda(lambda crops: t.stack(([normalize(transforms.ToTensor()(crop)) for crop in crops]))),
        transforms.ToTensor(),
        normalize
    ])
    # train_data = ImageNet_200k(opt.train_data_root, train=True, transforms=data_transforms)
    # val_data = ImageNet_200k(opt.val_data_root, train = False, transforms=data_transforms)
    train_data = datasets.ImageFolder(opt.train_data_root,
                                      train_data_transforms)
    val_data = datasets.ImageFolder(opt.val_data_root, val_data_transforms)
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers,
                                  pin_memory=True)
    # train_dataloader = DataLoader(train_data, opt.batch_size, shuffle=True, num_workers=opt.num_workers)

    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers,
                                pin_memory=True)
    # val_dataloader = DataLoader(val_data, opt.batch_size, shuffle=False, num_workers=opt.num_workers)

    # step3: criterion and optimizer

    mse_loss = t.nn.MSELoss(size_average=False)

    if opt.use_imp:
        # rate_loss = RateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)
        rate_loss = LimuRateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)

    lr = opt.lr

    # print ('model.parameters():')
    # print (model.parameters())

    # optimizer = t.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))
    optimizer = t.optim.Adam(filter(lambda p: p.requires_grad,
                                    model.parameters()),
                             lr=lr,
                             betas=(0.9, 0.999))

    start_epoch = 0

    if opt.resume:
        if use_data_parallel:
            start_epoch = model.module.load(
                None if opt.finetune else optimizer, opt.resume, opt.finetune)
        else:
            start_epoch = model.load(None if opt.finetune else optimizer,
                                     opt.resume, opt.finetune)

        if opt.finetune:
            print('Finetune from model checkpoint file', opt.resume)
        else:
            print('Resume training from checkpoint file', opt.resume)
            print('Continue training at epoch %d.' % start_epoch)

    # step4: meters
    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    previous_loss = 1e100
    tolerant_now = 0
    same_lr_epoch = 0

    # ps init

    ps.new_plot('train mse loss',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_mse_loss")
    ps.new_plot('val mse loss', 1, xlabel="epoch", ylabel="val_mse_loss")
    if opt.use_imp:
        ps.new_plot('train rate value',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_value")
        ps.new_plot('train rate loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_loss")
        ps.new_plot('train total loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_total_loss")
        ps.new_plot('val rate value',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_value")
        ps.new_plot('val rate loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_loss")
        ps.new_plot('val total loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_total_loss")

    for epoch in range(start_epoch + 1, opt.max_epoch + 1):

        same_lr_epoch += 1
        # per epoch avg loss meter
        mse_loss_meter.reset()
        if opt.use_imp:
            rate_display_meter.reset()
            rate_loss_meter.reset()
            total_loss_meter.reset()
        else:
            total_loss_meter = mse_loss_meter
        # cur_epoch_loss refresh every epoch
        # vis.refresh_plot('cur epoch train mse loss')
        ps.new_plot("cur epoch train mse loss",
                    opt.print_freq,
                    xlabel="iteration in cur epoch",
                    ylabel="train_mse_loss")
        # progress_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader), ascii=True)
        # progress_bar.set_description('epoch %d/%d, loss = 0.00' % (epoch, opt.max_epoch))

        # Init val
        if (epoch == start_epoch + 1) and opt.init_val:
            print('Init validation ... ')
            if opt.use_imp:
                mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                    model, val_dataloader, mse_loss, rate_loss, ps)
            else:
                mse_val_loss = val(model, val_dataloader, mse_loss, None, ps)

            ps.add_point('val mse loss', mse_val_loss)
            if opt.use_imp:
                ps.add_point('val rate value', rate_val_display)
                ps.add_point('val rate loss', rate_val_loss)
                ps.add_point('val total loss', total_val_loss)

            # make plot
            # ps.make_plot('train mse loss', "")   # all epoch share a same img, so give "" to epoch
            # ps.make_plot('cur epoch train mse loss',epoch)
            ps.make_plot('val mse loss')

            if opt.use_imp:
                # ps.make_plot("train rate value","")
                # ps.make_plot("train rate loss","")
                # ps.make_plot("train total loss","")
                ps.make_plot('val rate value')
                ps.make_plot('val rate loss')
                ps.make_plot('val total loss')

            # log sth.
            if opt.use_imp:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                    .format(epoch=epoch,
                            lr=lr,
                            val_mse_loss=mse_val_loss,
                            val_rate_loss=rate_val_loss,
                            val_total_loss=total_val_loss,
                            val_rate_display=rate_val_display))
            else:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, val_mse_loss:{val_mse_loss}'
                    .format(
                        epoch=epoch,
                        lr=lr,
                        # train_mse_loss = mse_loss_meter.value()[0],
                        val_mse_loss=mse_val_loss))

        model.train()

        # mse_val_loss = val(model, val_dataloader, mse_loss, None, ps)

        # _ is the corresponding label; the compression model doesn't use it.
        for idx, (data, _) in enumerate(train_dataloader):
            ipt = Variable(data)

            if opt.use_gpu:
                # ipt is also the reconstruction target, so it still needs to be moved to CUDA even under DataParallel
                ipt = ipt.cuda(async=True)

            optimizer.zero_grad()
            reconstructed, imp_mask_sigmoid = model(ipt)

            # print ('imp_mask_height', model.imp_mask_height)
            # pdb.set_trace()

            # print ('type recons', type(reconstructed.data))
            loss = mse_loss(reconstructed, ipt)
            caffe_loss = loss / (2 * opt.batch_size)

            if opt.use_imp:
                # rate_loss_display = model.imp_mask_sigmoid
                # rate_loss_display = (model.module if use_data_parallel else model).imp_mask_sigmoid
                rate_loss_display = imp_mask_sigmoid
                # rate_loss_display = model.imp_mask
                # print ('type of display', type(rate_loss_display.data))
                rate_loss_ = rate_loss(rate_loss_display)
                # print (
                #     'type of rate_loss_value',
                #     type(rate_loss_value.data)
                # )
                total_loss = caffe_loss + rate_loss_
            else:
                total_loss = caffe_loss

            # 1.
            total_loss.backward()
            # caffe_loss.backward()
            optimizer.step()

            mse_loss_meter.add(caffe_loss.data[0])

            if opt.use_imp:
                rate_loss_meter.add(rate_loss_.data[0])
                rate_display_meter.add(rate_loss_display.data.mean())
                total_loss_meter.add(total_loss.data[0])

            if idx % opt.print_freq == opt.print_freq - 1:
                ps.add_point(
                    'train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                ps.add_point(
                    'cur epoch train mse loss',
                    mse_loss_meter.value()[0]
                    if opt.print_smooth else caffe_loss.data[0])
                # print (rate_loss_display.data.mean())
                if opt.use_imp:
                    ps.add_point(
                        'train rate value',
                        rate_display_meter.value()[0]
                        if opt.print_smooth else rate_loss_display.data.mean())
                    ps.add_point(
                        'train rate loss',
                        rate_loss_meter.value()[0]
                        if opt.print_smooth else rate_loss_.data[0])
                    ps.add_point(
                        'train total loss',
                        total_loss_meter.value()[0]
                        if opt.print_smooth else total_loss.data[0])
                # pdb.set_trace()
                # progress_bar.set_description('epoch %d/%d, loss = %.2f' % (epoch, opt.max_epoch, total_loss.data[0]))

                #  2.
                # ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' % (epoch, opt.max_epoch, idx, len(train_dataloader), total_loss_meter.value()[0], lr))
                # ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' % (epoch, opt.max_epoch, idx, len(train_dataloader), mse_loss_meter.value()[0], lr))

                # ps.log('loss = %f' % caffe_loss.data[0])
                # print(total_loss.data[0])
                # input('waiting......')

                if not opt.use_imp:
                    ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' %
                           (epoch, opt.max_epoch, idx, len(train_dataloader),
                            total_loss_meter.value()[0], lr))
                else:
                    ps.log(
                        'Epoch %d/%d, Iter %d/%d, loss = %.2f, mse_loss = %.2f, rate_loss = %.2f, rate_display = %.2f, lr = %.8f'
                        %
                        (epoch, opt.max_epoch, idx, len(train_dataloader),
                         total_loss_meter.value()[0],
                         mse_loss_meter.value()[0], rate_loss_meter.value()[0],
                         rate_display_meter.value()[0], lr))
                # enter debug mode if the debug flag file exists
                if os.path.exists(opt.debug_file):
                    # import pdb
                    pdb.set_trace()

        if use_data_parallel:
            # print (type(model.module))
            # print (model)
            # print (type(model))
            model.module.save(optimizer, epoch)
        else:
            model.save(optimizer, epoch)

        # print ('case error', total_loss.data[0])
        # print ('smoothed error', total_loss_meter.value()[0])

        # plot before validation so the training curves are updated earlier
        ps.make_plot('train mse loss')  # all epochs share one image, so the epoch argument is left at its default
        ps.make_plot('cur epoch train mse loss', epoch)
        if opt.use_imp:
            ps.make_plot("train rate value")
            ps.make_plot("train rate loss")
            ps.make_plot("train total loss")

        # val
        if opt.use_imp:
            mse_val_loss, rate_val_loss, total_val_loss, rate_val_display = val(
                model, val_dataloader, mse_loss, rate_loss, ps)
        else:
            mse_val_loss = val(model, val_dataloader, mse_loss, None, ps)

        ps.add_point('val mse loss', mse_val_loss)
        if opt.use_imp:
            ps.add_point('val rate value', rate_val_display)
            ps.add_point('val rate loss', rate_val_loss)
            ps.add_point('val total loss', total_val_loss)

        # make plot
        # ps.make_plot('train mse loss', "")   # all epoch share a same img, so give "" to epoch
        # ps.make_plot('cur epoch train mse loss',epoch)
        ps.make_plot('val mse loss')

        if opt.use_imp:
            # ps.make_plot("train rate value","")
            # ps.make_plot("train rate loss","")
            # ps.make_plot("train total loss","")
            ps.make_plot('val rate value')
            ps.make_plot('val rate loss')
            ps.make_plot('val total loss')

        # log sth.
        if opt.use_imp:
            ps.log(
                'Epoch:{epoch}, lr:{lr}, train_mse_loss: {train_mse_loss}, train_rate_loss: {train_rate_loss}, train_total_loss: {train_total_loss}, train_rate_display: {train_rate_display} \n\
val_mse_loss: {val_mse_loss}, val_rate_loss: {val_rate_loss}, val_total_loss: {val_total_loss}, val_rate_display: {val_rate_display} '
                .format(epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        train_rate_loss=rate_loss_meter.value()[0],
                        train_total_loss=total_loss_meter.value()[0],
                        train_rate_display=rate_display_meter.value()[0],
                        val_mse_loss=mse_val_loss,
                        val_rate_loss=rate_val_loss,
                        val_total_loss=total_val_loss,
                        val_rate_display=rate_val_display))
        else:
            ps.log(
                'Epoch:{epoch}, lr:{lr}, train_mse_loss:{train_mse_loss}, val_mse_loss:{val_mse_loss}'
                .format(epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        val_mse_loss=mse_val_loss))

        # Adaptively adjust the learning rate:
        # for each lr, if the loss exceeds the previous value opt.tolerant_max times,
        # decay the learning rate.
        # if loss_meter.value()[0] > previous_loss:
        if opt.use_early_adjust:
            if total_loss_meter.value()[0] > previous_loss:
                tolerant_now += 1
                if tolerant_now == opt.tolerant_max:
                    tolerant_now = 0
                    same_lr_epoch = 0
                    lr = lr * opt.lr_decay
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Due to early stop anneal lr to', lr, 'at epoch',
                          epoch)
                    ps.log('Due to early stop anneal lr to %.10f at epoch %d' %
                           (lr, epoch))

            else:
                tolerant_now -= 1

        # if same_lr_epoch and same_lr_epoch % opt.lr_anneal_epochs == 0:
        #     same_lr_epoch = 0
        #     tolerant_now = 0
        #     lr = lr * opt.lr_decay
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = lr
        #     print ('Due to full epochs anneal lr to',lr,'at epoch',epoch)
        #     ps.log ('Due to full epochs anneal lr to %.10f at epoch %d' % (lr, epoch))

        if opt.use_file_decay_lr and os.path.exists(opt.lr_decay_file):
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        # previous_loss = total_loss_meter.value()[0] if opt.use_imp else mse_loss_meter.value()[0]
        previous_loss = total_loss_meter.value()[0]
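
The rate_loss criterion used above is not defined in this snippet. In the importance-map (content-weighted) compression formulation it is typically a penalty on the mean of the importance map that only applies above a target rate r, scaled by a weight γ (cf. the opt.rate_loss_threshold and opt.rate_loss_weight options referenced later). A minimal, hypothetical sketch under that assumption, not the author's actual implementation:

import torch as t
import torch.nn as nn


class RateLoss(nn.Module):
    """Hypothetical rate penalty on the importance map (sketch only)."""

    def __init__(self, threshold, weight):
        super(RateLoss, self).__init__()
        self.threshold = threshold   # target rate r (cf. opt.rate_loss_threshold)
        self.weight = weight         # penalty weight γ (cf. opt.rate_loss_weight)

    def forward(self, imp_map):
        # penalize only the part of the mean importance that exceeds the target rate
        excess = imp_map.mean() - self.threshold
        return self.weight * t.clamp(excess, min=0.0)
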
Example No. 16
def val(model, dataloader, mse_loss, rate_loss, ps):
    if ps:
        ps.log('validating ... ')
    else:
        print('run val ... ')
    model.eval()
    # avg_loss = 0
    # progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), ascii=True)
    # print(type(next(iter(dataloader))))
    # print(next(iter(dataloader)))
    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    mse_loss_meter.reset()
    if opt.use_imp:
        rate_display_meter.reset()
        rate_loss_meter.reset()
        total_loss_meter.reset()

    for idx, (data, _) in enumerate(dataloader):
        # ps.log('%.0f%%' % (idx*100.0/len(dataloader)))
        val_input = Variable(data, volatile=True)
        if opt.use_gpu:
            val_input = val_input.cuda(async=True)

        reconstructed, imp_mask_sigmoid = model(val_input)

        batch_loss = mse_loss(reconstructed, val_input)
        batch_caffe_loss = batch_loss / (2 * opt.batch_size)

        if opt.use_imp and rate_loss:
            rate_loss_display = imp_mask_sigmoid
            rate_loss_value = rate_loss(rate_loss_display)
            total_loss = batch_caffe_loss + rate_loss_value

        mse_loss_meter.add(batch_caffe_loss.data[0])

        if opt.use_imp:
            rate_loss_meter.add(rate_loss_value.data[0])
            rate_display_meter.add(rate_loss_display.data.mean())
            total_loss_meter.add(total_loss.data[0])
        # progress_bar.set_description('val_iter %d: loss = %.2f' % (idx+1, batch_caffe_loss.data[0]))
        # progress_bar.set_description('val_iter %d: loss = %.2f' % (idx+1, total_loss_meter.value()[0] if opt.use_imp else mse_loss_meter.value()[0]))

        # avg_loss += batch_caffe_loss.data[0]

    # avg_loss /= len(dataloader)
    # print('avg_loss =', avg_loss)
    # print('meter loss =', loss_meter.value[0])
    # print ('Total avg loss = {}'.format(avg_loss))
    if opt.use_imp:
        return mse_loss_meter.value()[0], rate_loss_meter.value(
        )[0], total_loss_meter.value()[0], rate_display_meter.value()[0]
    else:
        return mse_loss_meter.value()[0]
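
The division by 2 * opt.batch_size that appears throughout these snippets mirrors Caffe's EuclideanLoss, i.e. the summed squared error over the batch divided by 2N, which implies the mse_loss criterion is built with a summed (not averaged) reduction. A self-contained sketch of that normalization, written against the current PyTorch API (the snippets themselves use the pre-0.4 Variable API, where the equivalent option is size_average=False):

import torch as t
import torch.nn as nn

batch_size = 4
# stand-in tensors for a reconstruction and its target
reconstructed = t.randn(batch_size, 3, 32, 32)
target = t.randn(batch_size, 3, 32, 32)

sum_sq_err = nn.MSELoss(reduction='sum')(reconstructed, target)  # summed squared error
caffe_loss = sum_sq_err / (2 * batch_size)                       # Caffe EuclideanLoss convention
print(caffe_loss.item())
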
def test(model, dataloader, test_batch_size, mse_loss, rate_loss):

    model.eval()
    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))

    if opt.save_test_img:
        if not os.path.exists(opt.test_imgs_save_path):
            os.makedirs(opt.test_imgs_save_path)
            print('Created directory %s for saving test images!' % opt.test_imgs_save_path)

    eval_loss = True  # mse_loss, rate_loss, rate_disp
    eval_on_RGB = True  # RGB_mse, RGB_psnr

    revert_transforms = transforms.Compose([
        transforms.Normalize((-1, -1, -1), (2, 2, 2)),
        # transforms.Normalize([-0.485/0.229, -0.456/0.224, -0.406/0.225], [1/0.229, 1/0.224, 1/0.225]),
        transforms.Lambda(lambda tensor: t.clamp(tensor, 0.0, 1.0)),
        transforms.ToPILImage()
    ])

    mse = lambda x, y: np.mean(np.square(y.astype(np.float64) - x.astype(np.float64)))  # cast first: inputs are uint8 arrays and would wrap on subtraction
    psnr = lambda x, y: 10 * math.log10(255.**2 / mse(x, y))

    mse_meter = AverageValueMeter()
    psnr_meter = AverageValueMeter()
    rate_disp_meter = AverageValueMeter()
    caffe_loss_meter = AverageValueMeter()
    rate_loss_meter = AverageValueMeter()

    if opt.use_imp:
        total_loss_meter = AverageValueMeter()
    else:
        total_loss_meter = caffe_loss_meter

    for idx, data in progress_bar:
        test_data = Variable(data, volatile=True)
        # test_mask = Variable(mask, volatile=True)
        # pdb.set_trace()
        # mask[mask==1] = 0
        # mask[mask==opt.contrastive_degree] = 1
        # print ('type.mask', type(mask))

        # o_mask_as_weight = o_mask.clone()
        # pdb.set_trace()
        # bbox_inner = (o_mask == opt.mse_bbox_weight)
        # bbox_outer = (o_mask == 1)
        # o_mask[bbox_inner] = 1
        # o_mask[bbox_outer] = opt.mse_bbox_weight
        # print (o_mask)
        # pdb.set_trace()
        # o_mask[...] = 1
        # test_o_mask = Variable(o_mask, volatile=True)
        # test_o_mask_as_weight = Variable(o_mask_as_weight, volatile=True)

        # pdb.set_trace()
        if opt.use_gpu:
            test_data = test_data.cuda(async=True)
            # test_mask = test_mask.cuda(async=True)
            # test_o_mask = test_o_mask.cuda(async=True)
            # test_o_mask_as_weight = test_o_mask_as_weight.cuda(async=True)
            # o_mask = o_mask.cuda(async=True)

        # pdb.set_trace()
        if opt.use_imp:
            reconstructed, imp_mask_sigmoid = model(test_data)
        else:
            reconstructed = model(test_data)

        # only save the first image of each batch
        img_origin = revert_transforms(test_data.data.cpu()[0])
        img_reconstructed = revert_transforms(reconstructed.data.cpu()[0])

        if opt.save_test_img:
            if opt.use_imp:
                imp_map = transforms.ToPILImage()(
                    imp_mask_sigmoid.data.cpu()[0])
                imp_map = imp_map.resize(
                    (imp_map.size[0] * 8, imp_map.size[1] * 8))
                imp_map.save(
                    os.path.join(opt.test_imgs_save_path, "%d_imp.png" % idx))
            img_origin.save(
                os.path.join(opt.test_imgs_save_path, "%d_origin.png" % idx))
            img_reconstructed.save(
                os.path.join(opt.test_imgs_save_path, "%d_reconst.png" % idx))

        if eval_loss:
            mse_loss_v = mse_loss(reconstructed, test_data)

            caffe_loss_v = mse_loss_v / (2 * test_batch_size)

            caffe_loss_meter.add(caffe_loss_v.data[0])
            if opt.use_imp:
                rate_disp_meter.add(imp_mask_sigmoid.data.mean())

                assert rate_loss is not None
                rate_loss_v = rate_loss(imp_mask_sigmoid)
                rate_loss_meter.add(rate_loss_v.data[0])

                total_loss_v = caffe_loss_v + rate_loss_v
                total_loss_meter.add(total_loss_v.data[0])

        if eval_on_RGB:
            origin_arr = np.array(img_origin)
            reconst_arr = np.array(img_reconstructed)
            RGB_mse_v = mse(origin_arr, reconst_arr)
            RGB_psnr_v = psnr(origin_arr, reconst_arr)
            mse_meter.add(RGB_mse_v)
            psnr_meter.add(RGB_psnr_v)

    if eval_loss:
        print(
            'avg_mse_loss = {m_l}, avg_rate_loss = {r_l}, avg_rate_disp = {r_d}, avg_tot_loss = {t_l}'
            .format(m_l=caffe_loss_meter.value()[0],
                    r_l=rate_loss_meter.value()[0],
                    r_d=rate_disp_meter.value()[0],
                    t_l=total_loss_meter.value()[0]))
    if eval_on_RGB:
        print('RGB avg mse = {mse}, RGB avg psnr = {psnr}'.format(
            mse=mse_meter.value()[0], psnr=psnr_meter.value()[0]))
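
The RGB-domain metrics in test() are plain per-pixel MSE and PSNR = 10 * log10(255^2 / MSE) over the de-normalized images. A small self-contained check of those two lambdas (note the cast to float, since np.array(PIL.Image) yields uint8 arrays that would otherwise wrap on subtraction):

import math
import numpy as np

mse = lambda x, y: np.mean(np.square(y.astype(np.float64) - x.astype(np.float64)))
psnr = lambda x, y: 10 * math.log10(255. ** 2 / mse(x, y))

a = np.full((8, 8, 3), 100, dtype=np.uint8)
b = np.full((8, 8, 3), 110, dtype=np.uint8)   # constant error of 10 per channel
print(mse(a, b))    # 100.0
print(psnr(a, b))   # 10 * log10(65025 / 100), roughly 28.13 dB
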
def train(**kwargs):
    opt.parse(kwargs)
    # log file
    ps = PlotSaver("FrozenCNN_ResNet50_RGB_" +
                   time.strftime("%m_%d_%H:%M:%S") + ".log.txt")

    # step1: Model
    compression_model = getattr(models, opt.model)(
        use_imp=opt.use_imp,
        model_name="CWCNN_limu_ImageNet_imp_r={r}_γ={w}_for_resnet50".format(
            r=opt.rate_loss_threshold, w=opt.rate_loss_weight)
        if opt.use_imp else None)

    compression_model.load(None, opt.compression_model_ckpt)
    compression_model.eval()

    # if use_original_RGB:
    #     resnet_50 = resnet50()    # Official ResNet
    # else:
    #     resnet_50 = ResNet50()    # My ResNet
    c_resnet_51 = cResNet51()
    if opt.use_gpu:
        # compression_model.cuda()
        # resnet_50.cuda()
        compression_model = multiple_gpu_process(compression_model)
        c_resnet_51 = multiple_gpu_process(c_resnet_51)

    # freeze the compression network
    for param in compression_model.parameters():
        # print (param.requires_grad)
        param.requires_grad = False

    cudnn.benchmark = True

    # pdb.set_trace()

    # step2: Data
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_data_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ])
    val_data_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(), normalize
    ])

    train_data = datasets.ImageFolder(opt.train_data_root,
                                      train_data_transforms)
    val_data = datasets.ImageFolder(opt.val_data_root, val_data_transforms)
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers,
                                  pin_memory=True)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers,
                                pin_memory=True)

    # step3: criterion and optimizer

    class_loss = t.nn.CrossEntropyLoss()
    lr = opt.lr
    # optimizer = t.optim.Adam(resnet_50.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=opt.weight_decay)
    optimizer = t.optim.SGD(c_resnet_51.parameters(),
                            lr=lr,
                            momentum=opt.momentum,
                            weight_decay=opt.weight_decay)
    start_epoch = 0

    if opt.resume:
        start_epoch = c_resnet_51.module.load(
            None if opt.finetune else optimizer, opt.resume, opt.finetune)

        if opt.finetune:
            print('Finetune from model checkpoint file', opt.resume)
        else:
            print('Resume training from checkpoint file', opt.resume)
            print('Continue training at epoch %d.' % start_epoch)

    # step4: meters
    class_loss_meter = AverageValueMeter()

    class_acc_top5_meter = AverageValueMeter()
    class_acc_top1_meter = AverageValueMeter()

    # class_loss_meter = AverageMeter()
    # class_acc_top5_meter = AverageMeter()
    # class_acc_top1_meter = AverageMeter()

    # ps init

    ps.new_plot('train class loss',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_CE_loss")
    ps.new_plot('val class loss', 1, xlabel="epoch", ylabel="val_CE_loss")

    ps.new_plot('train top_5 acc',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_top5_acc")
    ps.new_plot('train top_1 acc',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_top1_acc")

    ps.new_plot('val top_5 acc', 1, xlabel="epoch", ylabel="val_top_5_acc")
    ps.new_plot('val top_1 acc', 1, xlabel="epoch", ylabel="val_top_1_acc")

    for epoch in range(start_epoch + 1, opt.max_epoch + 1):
        # per epoch avg loss meter
        class_loss_meter.reset()

        class_acc_top1_meter.reset()
        class_acc_top5_meter.reset()

        # cur_epoch_loss refresh every epoch

        ps.new_plot("cur epoch train class loss",
                    opt.print_freq,
                    xlabel="iteration in cur epoch",
                    ylabel="cur_train_CE_loss")

        c_resnet_51.train()

        for idx, (data, label) in enumerate(train_dataloader):
            ipt = Variable(data)
            label = Variable(label)

            if opt.use_gpu:
                ipt = ipt.cuda()
                label = label.cuda()

            optimizer.zero_grad()
            # if not use_original_RGB:
            # compressed_RGB = compression_model(ipt)
            # else:
            # compressed_RGB = ipt
            # we only need the compressed features here, no decoding
            compressed_feat = compression_model(ipt, need_decode=False)
            # print ('RGB', compressed_RGB.requires_grad)
            predicted = c_resnet_51(compressed_feat)

            class_loss_ = class_loss(predicted, label)

            class_loss_.backward()
            optimizer.step()

            class_loss_meter.add(class_loss_.data[0])
            # class_loss_meter.update(class_loss_.data[0], ipt.size(0))

            acc1, acc5 = accuracy(predicted.data, label.data, topk=(1, 5))
            # pdb.set_trace()

            class_acc_top1_meter.add(acc1[0])
            class_acc_top5_meter.add(acc5[0])
            # class_acc_top1_meter.update(acc1[0], ipt.size(0))
            # class_acc_top5_meter.update(acc5[0], ipt.size(0))

            if idx % opt.print_freq == opt.print_freq - 1:
                ps.add_point(
                    'train class loss',
                    class_loss_meter.value()[0]
                    if opt.print_smooth else class_loss_.data[0])
                ps.add_point(
                    'cur epoch train class loss',
                    class_loss_meter.value()[0]
                    if opt.print_smooth else class_loss_.data[0])
                ps.add_point(
                    'train top_5 acc',
                    class_acc_top5_meter.value()[0]
                    if opt.print_smooth else acc5[0])
                ps.add_point(
                    'train top_1 acc',
                    class_acc_top1_meter.value()[0]
                    if opt.print_smooth else acc1[0])

                ps.log(
                    'Epoch %d/%d, Iter %d/%d, class loss = %.4f, top 5 acc = %.2f %%, top 1 acc  = %.2f %%, lr = %.8f'
                    % (epoch, opt.max_epoch, idx, len(train_dataloader),
                       class_loss_meter.value()[0],
                       class_acc_top5_meter.value()[0],
                       class_acc_top1_meter.value()[0], lr))
                # enter debug mode if the debug flag file exists
                if os.path.exists(opt.debug_file):
                    pdb.set_trace()

        if use_data_parallel:
            c_resnet_51.module.save(optimizer, epoch)

        # plot before validation so the training curves are updated earlier
        ps.make_plot('train class loss')  # all epochs share one image, so the epoch argument is left at its default
        ps.make_plot('cur epoch train class loss', epoch)
        ps.make_plot("train top_5 acc")
        ps.make_plot("train top_1 acc")

        val_class_loss, val_top5_acc, val_top1_acc = val(
            compression_model, c_resnet_51, val_dataloader, class_loss, None,
            ps)

        ps.add_point('val class loss', val_class_loss)
        ps.add_point('val top_5 acc', val_top5_acc)
        ps.add_point('val top_1 acc', val_top1_acc)

        ps.make_plot('val class loss')
        ps.make_plot('val top_5 acc')
        ps.make_plot('val top_1 acc')

        ps.log(
            'Epoch:{epoch}, lr:{lr}, train_class_loss: {train_class_loss}, train_top5_acc: {train_top5_acc} %, train_top1_acc: {train_top1_acc} %, \
val_class_loss: {val_class_loss}, val_top5_acc: {val_top5_acc} %, val_top1_acc: {val_top1_acc} %'
            .format(epoch=epoch,
                    lr=lr,
                    train_class_loss=class_loss_meter.value()[0],
                    train_top5_acc=class_acc_top5_meter.value()[0],
                    train_top1_acc=class_acc_top1_meter.value()[0],
                    val_class_loss=val_class_loss,
                    val_top5_acc=val_top5_acc,
                    val_top1_acc=val_top1_acc))

        # adjust lr
        if epoch in opt.lr_decay_step_list:
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
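
accuracy(predicted.data, label.data, topk=(1, 5)) is called above but not defined in this snippet. The standard top-k helper from the official PyTorch ImageNet example, which matches how the result is used here (percentages for top-1 and top-5), looks roughly like this:

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy in percent (sketch of the standard ImageNet helper)."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)   # indices of the top-k predictions
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].contiguous().view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
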
def val(model, dataloader, mse_loss, rate_loss, ps):
    if ps:
        ps.log('Validating ... ')
    else:
        print('Validating ... ')
    model.eval()

    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    mse_loss_meter.reset()
    if opt.use_imp:
        rate_display_meter.reset()
        rate_loss_meter.reset()
        total_loss_meter.reset()

    for idx, (data, mask, o_mask) in enumerate(dataloader):
        # ps.log('%.0f%%' % (idx*100.0/len(dataloader)))
        val_data = Variable(data, volatile=True)
        val_mask = Variable(mask, volatile=True)
        val_o_mask = Variable(o_mask, volatile=True)
        # pdb.set_trace()
        if opt.use_gpu:
            val_data = val_data.cuda(async=True)
            val_mask = val_mask.cuda(async=True)
            val_o_mask = val_o_mask.cuda(async=True)

        reconstructed, imp_mask_sigmoid = model(val_data, val_mask, val_o_mask)

        batch_loss = mse_loss(reconstructed, val_data, val_o_mask)
        batch_caffe_loss = batch_loss / (2 * opt.batch_size)

        if opt.use_imp and rate_loss:
            rate_loss_display = imp_mask_sigmoid
            rate_loss_value = rate_loss(rate_loss_display, val_mask)
            total_loss = batch_caffe_loss + rate_loss_value

        mse_loss_meter.add(batch_caffe_loss.data[0])

        if opt.use_imp:
            rate_loss_meter.add(rate_loss_value.data[0])
            rate_display_meter.add(rate_loss_display.data.mean())
            total_loss_meter.add(total_loss.data[0])

    if opt.use_imp:
        return mse_loss_meter.value()[0], rate_loss_meter.value(
        )[0], total_loss_meter.value()[0], rate_display_meter.value()[0]
    else:
        return mse_loss_meter.value()[0]
Example No. 20
def main_worker(gpu, save_dir, ngpus_per_node, args):
    # basic setup
    cudnn.benchmark = True
    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    if args.log_name is not None:
        log_dir = "runs/%s" % args.log_name
    else:
        log_dir = "runs/time-%d" % time.time()

    if not args.distributed or (args.rank % ngpus_per_node == 0):
        writer = SummaryWriter(logdir=log_dir)
    else:
        writer = None

    if not args.use_latent_flow:  # auto-encoder only
        args.prior_weight = 0
        args.entropy_weight = 0

    # multi-GPU setup
    model = PointFlow(args)
    if args.distributed:  # Multiple processes, single GPU per process
        if args.gpu is not None:

            def _transform_(m):
                return nn.parallel.DistributedDataParallel(
                    m,
                    device_ids=[args.gpu],
                    output_device=args.gpu,
                    check_reduction=True)

            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            model.multi_gpu_wrapper(_transform_)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = 0
        else:
            assert 0, "DistributedDataParallel constructor should always set the single device scope"
    elif args.gpu is not None:  # Single process, single GPU per process
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:  # Single process, multiple GPUs per process

        def _transform_(m):
            return nn.DataParallel(m)

        model = model.cuda()
        model.multi_gpu_wrapper(_transform_)

    # resume checkpoints
    start_epoch = 0
    optimizer = model.make_optimizer(args)
    if args.resume_checkpoint is None and os.path.exists(
            os.path.join(save_dir, 'checkpoint-latest.pt')):
        args.resume_checkpoint = os.path.join(
            save_dir, 'checkpoint-latest.pt')  # use the latest checkpoint
    if args.resume_checkpoint is not None:
        if args.resume_optimizer:
            model, optimizer, start_epoch = resume(
                args.resume_checkpoint,
                model,
                optimizer,
                strict=(not args.resume_non_strict))
        else:
            model, _, start_epoch = resume(args.resume_checkpoint,
                                           model,
                                           optimizer=None,
                                           strict=(not args.resume_non_strict))
        print('Resumed from: ' + args.resume_checkpoint)

    # initialize datasets and loaders
    tr_dataset = MyDataset(args.data_dir, istest=False)
    te_dataset = MyDataset(args.data_dir, istest=True)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            tr_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(dataset=tr_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=0,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True,
                                               worker_init_fn=init_np_seed)
    test_loader = torch.utils.data.DataLoader(dataset=te_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=0,
                                              pin_memory=True,
                                              drop_last=False,
                                              worker_init_fn=init_np_seed)

    # save dataset statistics
    # if not args.distributed or (args.rank % ngpus_per_node == 0):
    #     np.save(os.path.join(save_dir, "train_set_mean.npy"), tr_dataset.all_points_mean)
    #     np.save(os.path.join(save_dir, "train_set_std.npy"), tr_dataset.all_points_std)
    #     np.save(os.path.join(save_dir, "train_set_idx.npy"), np.array(tr_dataset.shuffle_idx))
    #     np.save(os.path.join(save_dir, "val_set_mean.npy"), te_dataset.all_points_mean)
    #     np.save(os.path.join(save_dir, "val_set_std.npy"), te_dataset.all_points_std)
    #     np.save(os.path.join(save_dir, "val_set_idx.npy"), np.array(te_dataset.shuffle_idx))

    # load classification dataset if needed
    if args.eval_classification:
        from datasets import get_clf_datasets

        def _make_data_loader_(dataset):
            return torch.utils.data.DataLoader(dataset=dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=0,
                                               pin_memory=True,
                                               drop_last=False,
                                               worker_init_fn=init_np_seed)

        clf_datasets = get_clf_datasets(args)
        clf_loaders = {
            k: [_make_data_loader_(ds) for ds in ds_lst]
            for k, ds_lst in clf_datasets.items()
        }
    else:
        clf_loaders = None

    # initialize the learning rate scheduler
    if args.scheduler == 'exponential':
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, args.exp_decay)
    elif args.scheduler == 'step':
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=args.epochs // 2,
                                              gamma=0.1)
    elif args.scheduler == 'linear':

        def lambda_rule(ep):
            lr_l = 1.0 - max(0, ep - 0.5 * args.epochs) / float(
                0.5 * args.epochs)
            return lr_l

        scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                lr_lambda=lambda_rule)
    else:
        assert 0, "args.schedulers should be either 'exponential' or 'linear'"

    # main training loop
    start_time = time.time()
    entropy_avg_meter = AverageValueMeter()
    latent_nats_avg_meter = AverageValueMeter()
    point_nats_avg_meter = AverageValueMeter()
    if args.distributed:
        print("[Rank %d] World size : %d" % (args.rank, dist.get_world_size()))

    print("Start epoch: %d End epoch: %d" % (start_epoch, args.epochs))
    for epoch in range(start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # adjust the learning rate
        if (epoch + 1) % args.exp_decay_freq == 0:
            scheduler.step(epoch=epoch)
            if writer is not None:
                writer.add_scalar('lr/optimizer', scheduler.get_lr()[0], epoch)

        # train for one epoch
        for bidx, data in enumerate(train_loader):
            idx_batch, tr_batch, te_batch = data['idx'], data[
                'train_points'], data['test_points']
            step = bidx + len(train_loader) * epoch
            model.train()
            inputs = tr_batch.cuda(args.gpu, non_blocking=True)
            out = model(inputs, optimizer, step, writer)
            entropy, prior_nats, recon_nats = out['entropy'], out[
                'prior_nats'], out['recon_nats']
            entropy_avg_meter.update(entropy)
            point_nats_avg_meter.update(recon_nats)
            latent_nats_avg_meter.update(prior_nats)
            if step % args.log_freq == 0:
                duration = time.time() - start_time
                start_time = time.time()
                print(
                    "[Rank %d] Epoch %d Batch [%2d/%2d] Time [%3.2fs] Entropy %2.5f LatentNats %2.5f PointNats %2.5f"
                    % (args.rank, epoch, bidx, len(train_loader), duration,
                       entropy_avg_meter.avg, latent_nats_avg_meter.avg,
                       point_nats_avg_meter.avg))

        # evaluate on the validation set
        # if not args.no_validation and (epoch + 1) % args.val_freq == 0:
        #     from utils import validate
        #     validate(test_loader, model, epoch, writer, save_dir, args, clf_loaders=clf_loaders)

        # save visualizations
        if (epoch + 1) % args.viz_freq == 0:
            # reconstructions
            model.eval()
            samples = model.reconstruct(inputs)
            results = []
            for idx in range(min(10, inputs.size(0))):
                res = visualize_point_clouds(samples[idx], inputs[idx], idx)
                results.append(res)
            res = np.concatenate(results, axis=1)
            scipy.misc.imsave(
                os.path.join(
                    save_dir, 'images',
                    'tr_vis_conditioned_epoch%d-gpu%s.png' %
                    (epoch, args.gpu)), res.transpose((1, 2, 0)))
            if writer is not None:
                writer.add_image('tr_vis/conditioned', torch.as_tensor(res),
                                 epoch)

            # samples
            if args.use_latent_flow:
                num_samples = min(10, inputs.size(0))
                num_points = inputs.size(1)
                _, samples = model.sample(num_samples, num_points)
                results = []
                for idx in range(num_samples):
                    res = visualize_point_clouds(samples[idx], inputs[idx],
                                                 idx)
                    results.append(res)
                res = np.concatenate(results, axis=1)
                scipy.misc.imsave(
                    os.path.join(
                        save_dir, 'images',
                        'tr_vis_sampled_epoch%d-gpu%s.png' %
                        (epoch, args.gpu)), res.transpose((1, 2, 0)))
                if writer is not None:
                    writer.add_image('tr_vis/sampled', torch.as_tensor(res),
                                     epoch)

        # save checkpoints
        if not args.distributed or (args.rank % ngpus_per_node == 0):
            if (epoch + 1) % args.save_freq == 0:
                save(model, optimizer, epoch + 1,
                     os.path.join(save_dir, 'checkpoint-%d.pt' % epoch))
                save(model, optimizer, epoch + 1,
                     os.path.join(save_dir, 'checkpoint-latest.pt'))
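
init_np_seed is passed as worker_init_fn to both DataLoaders but is not defined in this snippet. In the PointFlow codebase it re-seeds NumPy inside every dataloader worker from that worker's torch seed, so random subsampling differs across workers; a sketch of that helper:

import numpy as np
import torch


def init_np_seed(worker_id):
    # derive a per-worker NumPy seed from the worker's torch seed
    seed = torch.initial_seed()
    np.random.seed(seed % 4294967296)  # NumPy seeds must fit in 32 bits
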
Example No. 21
                              depth_map_size=depth_map_size)

network = network.cuda()  # move network to GPU
# Load trained model
try:
    network.load_state_dict(torch.load(opt.model))
    print('Trained net weights loaded')
except Exception as e:
    print('ERROR: Failed to load net weights:', e)
    exit()
network.eval()
# ========================================================== #

# =============DEFINE stuff for logs ======================================== #
# Overall average metrics
overall_chd_loss = AverageValueMeter()
overall_iou_loss = AverageValueMeter()
overall_f_score_5_percent = AverageValueMeter()

# Per shape category metrics
per_cat_items = defaultdict(lambda: 0)
per_cat_chd_loss = defaultdict(lambda: AverageValueMeter())
per_cat_iou_loss = defaultdict(lambda: AverageValueMeter())
per_cat_f_score_5_percent = defaultdict(lambda: AverageValueMeter())

if not os.path.exists(opt.model[:-4]):
    os.mkdir(opt.model[:-4])
    print(f'created dir {opt.model[:-4]}/ for saving outputs')
output_folder = opt.model[:-4]
# ========================================================== #
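
The meters above appear to follow the torchnet.meter AverageValueMeter interface (reset()/add()/value(), with value() returning the running mean first). A minimal, hypothetical sketch of how the per-category defaultdicts would be filled and reported; the category names and Chamfer-distance values below are made up for illustration:

from collections import defaultdict
from torchnet.meter import AverageValueMeter

overall_chd_loss = AverageValueMeter()
per_cat_items = defaultdict(lambda: 0)
per_cat_chd_loss = defaultdict(lambda: AverageValueMeter())

# hypothetical (category, Chamfer distance) results
for cat_name, chd in [('chair', 0.012), ('table', 0.015), ('chair', 0.010)]:
    per_cat_items[cat_name] += 1
    per_cat_chd_loss[cat_name].add(chd)
    overall_chd_loss.add(chd)

for cat_name, meter in per_cat_chd_loss.items():
    print('%s: n=%d, mean CHD = %.4f' % (cat_name, per_cat_items[cat_name], meter.value()[0]))
print('overall mean CHD = %.4f' % overall_chd_loss.value()[0])
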