Example #1
def evaluate(config, mode):
    date = datetime.datetime.now().strftime('%Y_%m_%d/%H-%M')
    save_path = os.path.join(config.LOG_DIR.save, date, config.eval_mode)
    exists_or_mkdir(save_path)

    print(toGreen('Loading checkpoint manager...'))
    ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, 10)
    #ckpt_manager = CKPT_Manager(config.PRETRAIN.LOG_DIR.ckpt, mode, 10)

    batch_size = config.batch_size
    sample_num = config.sample_num
    skip_length = np.array(config.skip_length)

    ## DEFINE SESSION
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))

    ## DEFINE MODEL
    print(toGreen('Building model...'))
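    # a 5x5 grid of control points parameterizes the warp (param_dim = 5**2 = 25)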
    num_control_points = 5
    param_dim = num_control_points ** 2
    inputs, outputs = get_evaluation_model(sample_num, param_dim, num_control_points, config.height, config.width)

    ## INITIALIZING VARIABLE
    print(toGreen('Initializing variables'))
    sess.run(tf.global_variables_initializer())
    print(toGreen('Loading checkpoint...'))
    ckpt_manager.load_ckpt(sess, by_score=config.load_ckpt_by_score)

    print(toYellow('======== EVALUATION START ========='))
    offset = '/data1/junyonglee/video_stab/eval'
    file_path = os.path.join(offset, 'train_unstab')
    test_video_list = np.array(sorted(tl.files.load_file_list(path=file_path, regx='.*', printable=False)))
    for k in np.arange(len(test_video_list)):
        test_video_name = test_video_list[k]
        eval_path_stab = os.path.join(offset, 'train_stab')
        eval_path_unstab = os.path.join(offset, 'train_unstab')

        cap_stab = cv2.VideoCapture(os.path.join(eval_path_stab, test_video_name))
        cap_unstab = cv2.VideoCapture(os.path.join(eval_path_unstab, test_video_name))

        total_frame_num = int(cap_unstab.get(cv2.CAP_PROP_FRAME_COUNT))
        base = os.path.basename(test_video_name)
        base_name = os.path.splitext(base)[0]

        fourcc = cv2.VideoWriter_fourcc('M','J','P','G')
        fps = cap_unstab.get(cv2.CAP_PROP_FPS)
        out = cv2.VideoWriter(os.path.join(save_path, str(k) + '_' + config.eval_mode + '_' + base_name + '_out.avi'), fourcc, fps, (2 * config.width, config.height))
        print(toYellow('reading filename: {}, total frame: {}'.format(test_video_name, total_frame_num)))

        # read frame
        def read_frame(cap):
            ref, frame = cap.read()
            if ref:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame / 255., (config.width, config.height))
            return ref, frame

        # reading all frames in the video
        total_frames_stab = []
        total_frames_unstab = []
        print(toGreen('reading all frames...'))
        while True:
            ref_stab, frame_stab = read_frame(cap_stab)
            ref_unstab, frame_unstab = read_frame(cap_unstab)

            if not ref_stab or not ref_unstab:
                break

            total_frames_stab.append(frame_stab)
            total_frames_unstab.append(frame_unstab)

        # seed the first skip_length[-1] - skip_length[0] frames of the unstable
        # sequence with their stable counterparts
        for i in np.arange(skip_length[-1] - skip_length[0]):
            total_frames_unstab[i] = total_frames_stab[i]

        print(toGreen('stabilizing video...'))
        total_frame_num = len(total_frames_unstab)
        total_frames_stab = np.array(total_frames_stab)
        total_frames_unstab = np.array(total_frames_unstab)

        sample_idx = skip_length
        for frame_idx in range(skip_length[-1] - skip_length[0], total_frame_num):

            batch_frames = total_frames_unstab[sample_idx]
            batch_frames = np.expand_dims(np.concatenate(batch_frames, axis = 2), axis = 0)

            feed_dict = {
                inputs['patches_t']: batch_frames,
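                # channels 18: = the last frame of the stack (an inference:
                # with a 7-frame window of RGB frames, channels run 0..20)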
                inputs['u_t']: batch_frames[:, :, :, 18:]
            }
            s_t_pred = np.squeeze(sess.run(outputs['s_t_pred'], feed_dict))

            output = np.uint8(np.concatenate([total_frames_unstab[sample_idx[-1]].copy(), s_t_pred], axis = 1) * 255.)
            output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
            out.write(np.uint8(output))

            total_frames_unstab[sample_idx[-1]] = total_frames_stab[sample_idx[-1]]

            print('{}/{} {}/{} frame index: {}'.format(k + 1, len(test_video_list), frame_idx, int(total_frame_num - 1), sample_idx), flush = True)
            sample_idx = sample_idx + 1

        cap_stab.release()
        cap_unstab.release()
        out.release()
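The loop above feeds a window of frames whose indices are given by `skip_length` and slides it one frame per step. A small illustration of how the window advances (the real `skip_length` comes from the config; the values below are assumptions):

import numpy as np

skip_length = np.array([0, 8, 16, 24, 32])  # assumed config values
sample_idx = skip_length
for frame_idx in range(skip_length[-1] - skip_length[0], 40):
    # sample_idx lists the frame indices fed to the network at this step
    print(frame_idx, sample_idx)
    sample_idx = sample_idx + 1  # the whole window shifts by one frame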
Example #2
def train(config):
    summary = SummaryWriter(config.LOG_DIR.log_scalar_train_itr)

    ## inputs
    inputs = {'b_t_1': None, 'b_t': None, 's_t_1': None, 's_t': None}
    inputs = collections.OrderedDict(sorted(inputs.items(),
                                            key=lambda t: t[0]))

    ## model
    print(toGreen('Loading Model...'))
    moduleNetwork = Network().to(device)
    moduleNetwork.apply(weights_init)
    moduleNetwork_gt = Network().to(device)
    print(moduleNetwork)

    ## checkpoint manager
    ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, config.mode,
                                config.max_ckpt_num)
    moduleNetwork.load_state_dict(
        torch.load('./network/network-default.pytorch'))
    moduleNetwork_gt.load_state_dict(
        torch.load('./network/network-default.pytorch'))

    ## data loader
    print(toGreen('Loading Data Loader...'))
    data_loader = Data_Loader(config,
                              is_train=True,
                              name='train',
                              thread_num=config.thread_num)
    data_loader_test = Data_Loader(config,
                                   is_train=False,
                                   name="test",
                                   thread_num=config.thread_num)

    data_loader.init_data_loader(inputs)
    data_loader_test.init_data_loader(inputs)

    ## loss, optim
    print(toGreen('Building Loss & Optim...'))
    MSE_sum = torch.nn.MSELoss(reduction='sum')
    MSE_mean = torch.nn.MSELoss()
    optimizer = optim.Adam(moduleNetwork.parameters(),
                           lr=config.lr_init,
                           betas=(config.beta1, 0.999))
    errs = collections.OrderedDict()

    print(toYellow('======== TRAINING START ========='))
    max_epoch = 10000
    itr = 0
    for epoch in np.arange(max_epoch):

        # train
        while True:
            itr_time = time.time()

            inputs, is_end = data_loader.get_feed()
            if is_end: break

            if config.loss == 'image':
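                # four flow variants pair the b_* / s_* frames at t and t-1;
                # the frozen moduleNetwork_gt below supplies no-grad reference
                # flow and mask targets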
                flow_bb = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['b_t'], inputs['b_t_1']),
                    size=(config.height, config.width),
                    mode='bilinear',
                    align_corners=False)
                flow_bs = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['b_t'], inputs['s_t_1']),
                    size=(config.height, config.width),
                    mode='bilinear',
                    align_corners=False)
                flow_sb = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['s_t'], inputs['b_t_1']),
                    size=(config.height, config.width),
                    mode='bilinear',
                    align_corners=False)
                flow_ss = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['s_t'], inputs['s_t_1']),
                    size=(config.height, config.width),
                    mode='bilinear',
                    align_corners=False)

                with torch.no_grad():
                    flow_ss_gt = torch.nn.functional.interpolate(
                        input=moduleNetwork_gt(inputs['s_t'], inputs['s_t_1']),
                        size=(config.height, config.width),
                        mode='bilinear',
                        align_corners=False)
                    s_t_warped_ss_mask_gt = warp(tensorInput=torch.ones_like(
                        inputs['s_t_1'], device=device),
                                                 tensorFlow=flow_ss_gt)

                s_t_warped_bb = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_bb)
                s_t_warped_bs = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_bs)
                s_t_warped_sb = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_sb)
                s_t_warped_ss = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_ss)

                s_t_warped_bb_mask = warp(tensorInput=torch.ones_like(
                    inputs['s_t_1'], device=device),
                                          tensorFlow=flow_bb)
                s_t_warped_bs_mask = warp(tensorInput=torch.ones_like(
                    inputs['s_t_1'], device=device),
                                          tensorFlow=flow_bs)
                s_t_warped_sb_mask = warp(tensorInput=torch.ones_like(
                    inputs['s_t_1'], device=device),
                                          tensorFlow=flow_sb)
                s_t_warped_ss_mask = warp(tensorInput=torch.ones_like(
                    inputs['s_t_1'], device=device),
                                          tensorFlow=flow_ss)

                optimizer.zero_grad()

                errs['MSE_bb'] = MSE_sum(
                    s_t_warped_bb * s_t_warped_bb_mask,
                    inputs['s_t']) / s_t_warped_bb_mask.sum()
                errs['MSE_bs'] = MSE_sum(
                    s_t_warped_bs * s_t_warped_bs_mask,
                    inputs['s_t']) / s_t_warped_bs_mask.sum()
                errs['MSE_sb'] = MSE_sum(
                    s_t_warped_sb * s_t_warped_sb_mask,
                    inputs['s_t']) / s_t_warped_sb_mask.sum()
                errs['MSE_ss'] = MSE_sum(
                    s_t_warped_ss * s_t_warped_ss_mask,
                    inputs['s_t']) / s_t_warped_ss_mask.sum()

                errs['MSE_bb_mask_shape'] = MSE_mean(s_t_warped_bb_mask,
                                                     s_t_warped_ss_mask_gt)
                errs['MSE_bs_mask_shape'] = MSE_mean(s_t_warped_bs_mask,
                                                     s_t_warped_ss_mask_gt)
                errs['MSE_sb_mask_shape'] = MSE_mean(s_t_warped_sb_mask,
                                                     s_t_warped_ss_mask_gt)
                errs['MSE_ss_mask_shape'] = MSE_mean(s_t_warped_ss_mask,
                                                     s_t_warped_ss_mask_gt)

                errs['total'] = errs['MSE_bb'] + errs['MSE_bs'] + errs['MSE_sb'] + errs['MSE_ss'] \
                              + errs['MSE_bb_mask_shape'] + errs['MSE_bs_mask_shape'] + errs['MSE_sb_mask_shape'] + errs['MSE_ss_mask_shape']

            if config.loss == 'image_ss':
                flow_ss = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['s_t'], inputs['s_t_1']),
                    size=(config.height, config.width),
                    mode='bilinear',
                    align_corners=False)
                with torch.no_grad():
                    flow_ss_gt = torch.nn.functional.interpolate(
                        input=moduleNetwork_gt(inputs['s_t'], inputs['s_t_1']),
                        size=(config.height, config.width),
                        mode='bilinear',
                        align_corners=False)
                    s_t_warped_ss_mask_gt = warp(tensorInput=torch.ones_like(
                        inputs['s_t_1'], device=device),
                                                 tensorFlow=flow_ss_gt)

                s_t_warped_ss = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_ss)
                s_t_warped_ss_mask = warp(tensorInput=torch.ones_like(
                    inputs['s_t_1'], device=device),
                                          tensorFlow=flow_ss)

                optimizer.zero_grad()

                errs['MSE_ss'] = MSE_sum(
                    s_t_warped_ss * s_t_warped_ss_mask,
                    inputs['s_t']) / s_t_warped_ss_mask.sum()
                errs['MSE_ss_mask_shape'] = MSE_mean(s_t_warped_ss_mask,
                                                     s_t_warped_ss_mask_gt)
                errs['total'] = errs['MSE_ss'] + errs['MSE_ss_mask_shape']

            if config.loss == 'flow_only':
                flow_bb = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['b_t'], inputs['b_t_1']),
                    size=(config.height, config.width),
                    mode='bilinear',
                    align_corners=False)
                flow_bs = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['b_t'], inputs['s_t_1']),
                    size=(config.height, config.width),
                    mode='bilinear',
                    align_corners=False)
                flow_sb = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['s_t'], inputs['b_t_1']),
                    size=(config.height, config.width),
                    mode='bilinear',
                    align_corners=False)
                flow_ss = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['s_t'], inputs['s_t_1']),
                    size=(config.height, config.width),
                    mode='bilinear',
                    align_corners=False)

                s_t_warped_ss = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_ss)

                with torch.no_grad():
                    flow_ss_gt = torch.nn.functional.interpolate(
                        input=moduleNetwork_gt(inputs['s_t'], inputs['s_t_1']),
                        size=(config.height, config.width),
                        mode='bilinear',
                        align_corners=False)

                optimizer.zero_grad()

                # liteflow_flow_only
                errs['MSE_bb_ss'] = MSE_mean(flow_bb, flow_ss_gt)
                errs['MSE_bs_ss'] = MSE_mean(flow_bs, flow_ss_gt)
                errs['MSE_sb_ss'] = MSE_mean(flow_sb, flow_ss_gt)
                errs['MSE_ss_ss'] = MSE_mean(flow_ss, flow_ss_gt)
                errs['total'] = errs['MSE_bb_ss'] + errs['MSE_bs_ss'] + errs[
                    'MSE_sb_ss'] + errs['MSE_ss_ss']

            errs['total'].backward()
            optimizer.step()

            lr = adjust_learning_rate(optimizer, epoch, config.decay_rate,
                                      config.decay_every, config.lr_init)

            if itr % config.write_log_every_itr == 0:
                summary.add_scalar('loss/loss_mse', errs['total'].item(), itr)
                vutils.save_image(inputs['s_t_1'].detach().cpu(),
                                  '{}/{}_1_input.png'.format(
                                      config.LOG_DIR.sample, itr),
                                  nrow=3,
                                  padding=0,
                                  normalize=False)
                vutils.save_image(s_t_warped_ss.detach().cpu(),
                                  '{}/{}_2_warped_ss.png'.format(
                                      config.LOG_DIR.sample, itr),
                                  nrow=3,
                                  padding=0,
                                  normalize=False)
                vutils.save_image(inputs['s_t'].detach().cpu(),
                                  '{}/{}_3_gt.png'.format(
                                      config.LOG_DIR.sample, itr),
                                  nrow=3,
                                  padding=0,
                                  normalize=False)

                if config.loss == 'image_ss':
                    vutils.save_image(s_t_warped_ss_mask.detach().cpu(),
                                      '{}/{}_4_s_t_warped_ss_mask.png'.format(
                                          config.LOG_DIR.sample, itr),
                                      nrow=3,
                                      padding=0,
                                      normalize=False)
                elif config.loss != 'flow_only':
                    vutils.save_image(s_t_warped_bb_mask.detach().cpu(),
                                      '{}/{}_4_s_t_warped_bb_mask.png'.format(
                                          config.LOG_DIR.sample, itr),
                                      nrow=3,
                                      padding=0,
                                      normalize=False)

            if itr % config.refresh_image_log_every_itr == 0:
                remove_file_end_with(config.LOG_DIR.sample, '*.png')

            print_logs('TRAIN',
                       config.mode,
                       epoch,
                       itr_time,
                       itr,
                       data_loader.num_itr,
                       errs=errs,
                       lr=lr)
            itr += 1

        if epoch % config.write_ckpt_every_epoch == 0:
            ckpt_manager.save_ckpt(moduleNetwork,
                                   epoch,
                                   score=errs['total'].item())
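For reference, the `warp` helper called above backward-warps a frame with a dense flow field. A minimal sketch of such a helper (an assumption based on the call sites and the usual grid_sample convention; the repository's own implementation is not shown):

import torch

def warp(tensorInput, tensorFlow):
    # identity sampling grid in normalized [-1, 1] coordinates
    B, _, H, W = tensorInput.size()
    xs = torch.linspace(-1.0, 1.0, W, device=tensorInput.device).view(1, 1, 1, W).expand(B, 1, H, W)
    ys = torch.linspace(-1.0, 1.0, H, device=tensorInput.device).view(1, 1, H, 1).expand(B, 1, H, W)
    grid = torch.cat([xs, ys], dim=1)

    # convert the flow from pixel units to normalized coordinates
    flow = torch.cat([tensorFlow[:, 0:1] / ((W - 1.0) / 2.0),
                      tensorFlow[:, 1:2] / ((H - 1.0) / 2.0)], dim=1)

    # sample the input at the displaced grid positions
    return torch.nn.functional.grid_sample(input=tensorInput,
                                           grid=(grid + flow).permute(0, 2, 3, 1),
                                           mode='bilinear',
                                           padding_mode='zeros',
                                           align_corners=True)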
Example #3
class Trainer():
    def __init__(self, config):
        self.config = config
        self.summary = SummaryWriter(config.LOG_DIR.log_scalar_train_itr)

        ## model
        print(toGreen('Loading Model...'))
        self.model = create_model(config)
        self.model.print()

        ## checkpoint manager
        self.ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, config.mode,
                                         config.max_ckpt_num)
        if config.load_pretrained:
            print(toGreen('Loading pretrained Model...'))
            load_result = self.model.get_network().load_state_dict(
                torch.load(os.path.join('./ckpt', config.pre_ckpt_name)))
            lr = config.lr_init * (config.decay_rate**(config.epoch_start //
                                                       config.decay_every))
            print(toRed('\tlearning rate: {}'.format(lr)))
            print(toRed('\tload result: {}'.format(load_result)))
            self.model.set_optim(lr)

        ## training vars
        self.max_epoch = 10000
        self.epoch_range = np.arange(config.epoch_start, self.max_epoch)
        self.itr_global = 0
        self.itr = 0
        self.err_epoch_train = 0
        self.err_epoch_test = 0

    def train(self):
        print(toYellow('======== TRAINING START ========='))
        for epoch in self.epoch_range:
            ## TRAIN ##
            self.itr = 0
            self.err_epoch_train = 0
            self.model.train()
            while True:
                if self.iteration(epoch): break
            err_epoch_train = self.err_epoch_train / self.itr

            ## TEST ##
            self.itr = 0
            self.err_epoch_test = 0
            self.model.eval()
            while True:
                with torch.no_grad():
                    if self.iteration(epoch, is_train=False): break
            err_epoch_test = self.err_epoch_test / self.itr

            ## LOG
            if epoch % self.config.write_ckpt_every_epoch == 0:
                self.ckpt_manager.save_ckpt(self.model.get_network(),
                                            epoch + 1,
                                            score=err_epoch_train)
                remove_file_end_with(self.config.LOG_DIR.sample, '*.png')

            self.summary.add_scalar('loss/epoch_train', err_epoch_train, epoch)
            self.summary.add_scalar('loss/epoch_test', err_epoch_test, epoch)

    def iteration(self, epoch, is_train=True):
        lr = None
        itr_time = time.time()

        is_end = self.model.iteration(epoch, is_train)
        if is_end: return True

        inputs = self.model.visuals['inputs']
        errs = self.model.visuals['errs']
        outs = self.model.visuals['outs']
        lr = self.model.visuals['lr']
        num_itr = self.model.visuals['num_itr']

        if is_train:
            lr = self.model.update(epoch)

            if self.itr % self.config.write_log_every_itr == 0:
                try:
                    self.summary.add_scalar('loss/itr', errs['total'].item(),
                                            self.itr_global)
                    vutils.save_image(inputs['input'],
                                      '{}/{}_{}_1_input.png'.format(
                                          self.config.LOG_DIR.sample, epoch,
                                          self.itr),
                                      nrow=3,
                                      padding=0,
                                      normalize=False)
                    i = 2
                    for key, val in outs.items():
                        if val is not None:
                            vutils.save_image(val,
                                              '{}/{}_{}_{}_out_{}.png'.format(
                                                  self.config.LOG_DIR.sample,
                                                  epoch, self.itr, i, key),
                                              nrow=3,
                                              padding=0,
                                              normalize=False)
                        i += 1

                    vutils.save_image(inputs['gt'],
                                      '{}/{}_{}_{}_gt.png'.format(
                                          self.config.LOG_DIR.sample, epoch,
                                          self.itr, i),
                                      nrow=3,
                                      padding=0,
                                      normalize=False)
                except Exception as ex:
                    print('saving error: ', ex)

            print_logs('TRAIN',
                       self.config.mode,
                       epoch,
                       itr_time,
                       self.itr,
                       num_itr,
                       errs=errs,
                       lr=lr)
            self.err_epoch_train += errs['total'].item()
            self.itr += 1
            self.itr_global += 1
        else:
            print_logs('TEST',
                       self.config.mode,
                       epoch,
                       itr_time,
                       self.itr,
                       num_itr,
                       errs=errs)
            self.err_epoch_test += errs['total'].item()
            self.itr += 1

        gc.collect()
        return is_end
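The resume path in __init__ recomputes the stepped learning rate in closed form. A minimal sketch of the matching `adjust_learning_rate` helper called in the earlier training-loop example (only the call-site signature appears in the source; the body is an assumption):

def adjust_learning_rate(optimizer, epoch, decay_rate, decay_every, lr_init):
    # step decay: multiply the rate by decay_rate once every decay_every epochs
    lr = lr_init * (decay_rate ** (epoch // decay_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr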
Example #4
def evaluate(config, mode):
    date = datetime.datetime.now().strftime('%Y_%m_%d/%H-%M')
    save_path = os.path.join(config.LOG_DIR.save, date, config.eval_mode)
    exists_or_mkdir(save_path)

    print(toGreen('Loading checkpoint manager...'))
    ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, 10)
    #ckpt_manager = CKPT_Manager(config.PRETRAIN.LOG_DIR.ckpt, mode, 10)

    batch_size = config.batch_size
    sample_num = config.sample_num
    skip_length = config.skip_length

    ## DEFINE SESSION
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))

    ## DEFINE MODEL
    print(toGreen('Building model...'))
    stabNet = StabNet(config.height,
                      config.width,
                      config.F_dim,
                      is_train=False)
    stable_path_net = stabNet.get_stable_path_init(sample_num)
    unstable_path_net = stabNet.get_unstable_path_init(sample_num)
    outputs_net = stabNet.init_evaluation_model(sample_num)

    ## INITIALIZING VARIABLE
    print(toGreen('Initializing variables'))
    sess.run(tf.global_variables_initializer())
    print(toGreen('Loading checkpoint...'))
    ckpt_manager.load_ckpt(sess, by_score=config.load_ckpt_by_score)

    print(toYellow('======== EVALUATION START ========='))
    test_video_list = np.array(
        sorted(
            tl.files.load_file_list(path=config.unstab_path,
                                    regx='.*',
                                    printable=False)))
    for k in np.arange(len(test_video_list)):
        test_video_name = test_video_list[k]
        cap = cv2.VideoCapture(
            os.path.join(config.unstab_path, test_video_name))
        fps = cap.get(cv2.CAP_PROP_FPS)
        resize_h = config.height
        resize_w = config.width
        # out_h = int(cap.get(4))
        # out_w = int(cap.get(3))
        out_h = resize_h
        out_w = resize_w

        # refine_temp = np.ones((h, w))
        # refine_temp = refine_image(refine_temp)
        # [h, w] = refine_temp.shape[:2]
        total_frame_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
        base = os.path.basename(test_video_name)
        base_name = os.path.splitext(base)[0]

        out = cv2.VideoWriter(
            os.path.join(
                save_path,
                str(k) + '_' + config.eval_mode + '_' + base_name +
                '_out.avi'), fourcc, fps, (3 * out_w, out_h))
        print(
            toYellow('reading filename: {}, total frame: {}'.format(
                test_video_name, total_frame_num)))

        # read frame
        def refine_frame(frame):
            return cv2.resize(frame / 255., (resize_w, resize_h))

        def read_frame(cap):
            ref, frame = cap.read()
            if ref:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame, (out_w, out_h))
            return ref, frame

        # reading all frames in the video
        total_frames = []
        print(toGreen('reading all frames...'))
        while True:
            ref, frame = read_frame(cap)
            if not ref:
                break
            total_frames.append(refine_frame(frame))

        # prepend (sample_num - 1) * skip_length copies of the first frame
        for i in np.arange((sample_num - 1) * skip_length):
            total_frames.insert(0, total_frames[0])

        print(toGreen('stabilizing video...'))
        total_frame_num = len(total_frames)
        total_frames = np.array(total_frames)

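        # per-frame caches: S / U hold the predicted stable / unstable path
        # states and C_0_list the per-frame C_0 term fed back into the network
        # (an inference from how they are filled below)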
        S = [None] * total_frame_num
        U = [None] * total_frame_num
        C_0_list = [None] * total_frame_num
        S_t_1_seq = None
        U_t_1_seq = None

        for i in np.arange((sample_num - 1) * skip_length):
            C_0_list[i] = np.zeros([1, out_h, out_w, 2])

        sample_idx = np.arange(0, sample_num * skip_length, skip_length)
        for frame_idx in range((sample_num - 1) * skip_length,
                               total_frame_num):

            batch_frames = total_frames[sample_idx]
            # add a batch dimension: (1, sample_num, H, W, C)
            batch_frames = np.expand_dims(batch_frames, axis=0)

            if U[sample_idx[0]] is None:
                feed_dict = {
                    stabNet.inputs['IS']: batch_frames[:, :-1, :, :, :]
                }
                stable_path_init = sess.run(stable_path_net, feed_dict)

                feed_dict = {
                    stabNet.inputs['IU']: batch_frames[:, :-1, :, :, :]
                }
                unstable_path_init = sess.run(unstable_path_net, feed_dict)

                S_t_1_seq = stable_path_init['S_t_1_seq']
                U_t_1_seq = unstable_path_init['U_t_1_seq']

                for i in np.arange(S_t_1_seq.shape[1]):
                    S[sample_idx[i + 1]] = S_t_1_seq[:, i:i + 1, :, :, :]
                    U[sample_idx[i + 1]] = U_t_1_seq[:, i:i + 1, :, :, :]

            idxs = sample_idx[1:-1]
            S_t_1_seq = np.concatenate([S[idx] for idx in idxs], axis=1)
            U_t_1_seq = np.concatenate([U[idx] for idx in idxs], axis=1)
            C_0 = C_0_list[sample_idx[0]]

            feed_dict = {
                stabNet.inputs['IU']:
                batch_frames[:, :-1, :, :, :],
                stabNet.inputs['Iu']:
                np.expand_dims(batch_frames[:, -1, :, :, :], axis=1),
                stabNet.inputs['U_t_1_seq']:
                U_t_1_seq,
                stabNet.inputs['S_t_1_seq']:
                S_t_1_seq,
                stabNet.inputs['C_0']:
                C_0,
            }
            Is_pred, Is_pred_wo_C0, S_t_pred, U_t, C_0 = sess.run([
                outputs_net['Is_pred'], outputs_net['Is_pred_wo_C0'],
                outputs_net['S_t_pred_seq'], outputs_net['U_t_seq'],
                outputs_net['B_t_wo_C0']
            ], feed_dict)

            S[sample_idx[-1]] = S_t_pred
            U[sample_idx[-1]] = U_t
            C_0_list[sample_idx[-1]] = C_0

            Is_pred = np.squeeze(Is_pred)
            Is_pred_wo_C0 = np.squeeze(Is_pred_wo_C0)

            output = np.uint8(
                np.concatenate(
                    (total_frames[frame_idx].copy(), Is_pred, Is_pred_wo_C0),
                    axis=1) * 255.)
            output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
            out.write(np.uint8(output))

            print('{}/{} {}/{} frame index: {}'.format(
                k + 1, len(test_video_list), frame_idx,
                int(total_frame_num - 1), sample_idx),
                  flush=True)
            sample_idx = sample_idx + 1

        cap.release()
        out.release()
Example #5
def train(config, mode):
    ## Managers
    print(toGreen('Loading checkpoint manager...'))
    ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, mode, config.max_ckpt_num)
    ckpt_manager_itr = CKPT_Manager(config.LOG_DIR.ckpt_itr, mode,
                                    config.max_ckpt_num)
    ckpt_manager_init = CKPT_Manager(config.PRETRAIN.LOG_DIR.ckpt, mode,
                                     config.max_ckpt_num)
    ckpt_manager_init_itr = CKPT_Manager(config.PRETRAIN.LOG_DIR.ckpt_itr,
                                         mode, config.max_ckpt_num)
    ckpt_manager_perm = CKPT_Manager(config.PRETRAIN.LOG_DIR.ckpt_perm, mode,
                                     1)

    ## DEFINE SESSION
    seed_value = 1
    tf.set_random_seed(seed_value)
    np.random.seed(seed_value)
    random.seed(seed_value)

    print(toGreen('Initializing session...'))
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))

    ## DEFINE MODEL
    stabNet = StabNet(config.height, config.width)

    ## DEFINE DATA LOADERS
    print(toGreen('Loading dataloader...'))
    data_loader = Data_Loader(config,
                              is_train=True,
                              thread_num=config.thread_num)
    data_loader_test = Data_Loader(config.TEST,
                                   is_train=False,
                                   thread_num=config.thread_num)

    ## DEFINE TRAINER
    print(toGreen('Initializing Trainer...'))
    trainer = Trainer(stabNet, [data_loader, data_loader_test], config)

    ## DEFINE SUMMARY WRITER
    print(toGreen('Building summary writer...'))
    if config.is_pretrain:
        writer_scalar_itr_init = tf.summary.FileWriter(
            config.PRETRAIN.LOG_DIR.log_scalar_train_itr,
            flush_secs=30,
            filename_suffix='.scalor_log_itr_init')
        writer_scalar_epoch_init = tf.summary.FileWriter(
            config.PRETRAIN.LOG_DIR.log_scalar_train_epoch,
            flush_secs=30,
            filename_suffix='.scalor_log_epoch_init')
        writer_scalar_epoch_valid_init = tf.summary.FileWriter(
            config.PRETRAIN.LOG_DIR.log_scalar_valid,
            flush_secs=30,
            filename_suffix='.scalor_log_epoch_test_init')
        writer_image_init = tf.summary.FileWriter(
            config.PRETRAIN.LOG_DIR.log_image,
            flush_secs=30,
            filename_suffix='.image_log_init')

    writer_scalar_itr = tf.summary.FileWriter(
        config.LOG_DIR.log_scalar_train_itr,
        flush_secs=30,
        filename_suffix='.scalor_log_itr')
    writer_scalar_epoch = tf.summary.FileWriter(
        config.LOG_DIR.log_scalar_train_epoch,
        flush_secs=30,
        filename_suffix='.scalor_log_epoch')
    writer_scalar_epoch_valid = tf.summary.FileWriter(
        config.LOG_DIR.log_scalar_valid,
        flush_secs=30,
        filename_suffix='.scalor_log_epoch_test')
    writer_image = tf.summary.FileWriter(config.LOG_DIR.log_image,
                                         flush_secs=30,
                                         filename_suffix='.image_log')

    ## INITIALIZE SESSION
    print(toGreen('Initializing network...'))
    sess.run(tf.global_variables_initializer())
    trainer.init_vars(sess)

    ckpt_manager_init.load_ckpt(sess, by_score=False)
    ckpt_manager_perm.load_ckpt(sess)

    if config.is_pretrain:
        print(toYellow('======== PRETRAINING START ========='))
        global_step = 0
        for epoch in range(0, config.PRETRAIN.n_epoch):

            # update learning rate
            trainer.update_learning_rate(epoch, config.PRETRAIN.lr_init,
                                         config.PRETRAIN.lr_decay_rate,
                                         config.PRETRAIN.decay_every, sess)

            errs_total_pretrain = collections.OrderedDict.fromkeys(
                trainer.pretrain_loss.keys(), 0.)
            errs = None
            epoch_time = time.time()
            idx = 0
            while True:
                step_time = time.time()

                feed_dict, is_end = data_loader.feed_the_network()
                if is_end: break
                feed_dict = trainer.adjust_loss_coef(feed_dict, epoch, errs)
                _, lr, errs = sess.run([
                    trainer.optim_init, trainer.learning_rate,
                    trainer.pretrain_loss
                ], feed_dict)
                errs_total_pretrain = dict_operations(errs_total_pretrain, '+',
                                                      errs)

                if global_step % config.write_log_every_itr == 0:
                    summary_loss_itr, summary_image = sess.run(
                        [trainer.scalar_sum_itr_init, trainer.image_sum_init],
                        feed_dict)
                    writer_scalar_itr_init.add_summary(summary_loss_itr,
                                                       global_step)
                    writer_image_init.add_summary(summary_image, global_step)

                # save checkpoint
                if global_step % config.PRETRAIN.write_ckpt_every_itr == 0:
                    ckpt_manager_init_itr.save_ckpt(
                        sess,
                        trainer.pretraining_save_vars,
                        '{:05d}_{:05d}'.format(epoch, global_step),
                        score=errs_total_pretrain['total'] / (idx + 1))

                print_logs('PRETRAIN',
                           mode,
                           epoch,
                           step_time,
                           idx,
                           data_loader.num_itr,
                           errs=errs,
                           coefs=trainer.coef_container,
                           lr=lr)

                global_step += 1
                idx += 1

            # save log
            errs_total_pretrain = dict_operations(errs_total_pretrain, '/',
                                                  data_loader.num_itr)
            summary_loss_epoch_init = sess.run(
                trainer.summary_epoch_init,
                feed_dict=dict_operations(trainer.loss_epoch_init_placeholder,
                                          '=', errs_total_pretrain))
            writer_scalar_epoch_init.add_summary(summary_loss_epoch_init,
                                                 epoch)

            ## TEST
            errs_total_pretrain_test = collections.OrderedDict.fromkeys(
                trainer.pretrain_loss_test.keys(), 0.)
            errs = None
            epoch_time_test = time.time()
            idx = 0
            while True:
                step_time = time.time()

                feed_dict, is_end = data_loader_test.feed_the_network()
                if is_end: break
                feed_dict = trainer.adjust_loss_coef(feed_dict, epoch, errs)
                errs = sess.run(trainer.pretrain_loss_test, feed_dict)

                errs_total_pretrain_test = dict_operations(
                    errs_total_pretrain_test, '+', errs)

                print_logs('PRETRAIN TEST',
                           mode,
                           epoch,
                           step_time,
                           idx,
                           data_loader_test.num_itr,
                           errs=errs,
                           coefs=trainer.coef_container)
                idx += 1

            # save log
            errs_total_pretrain_test = dict_operations(
                errs_total_pretrain_test, '/', data_loader_test.num_itr)
            summary_loss_test_init = sess.run(
                trainer.summary_epoch_init,
                feed_dict=dict_operations(trainer.loss_epoch_init_placeholder,
                                          '=', errs_total_pretrain_test))
            writer_scalar_epoch_valid_init.add_summary(summary_loss_test_init,
                                                       epoch)

            print_logs('TRAIN SUMMARY',
                       mode,
                       epoch,
                       epoch_time,
                       errs=errs_total_pretrain)
            print_logs('TEST SUMMARY',
                       mode,
                       epoch,
                       epoch_time_test,
                       errs=errs_total_pretrain_test)

            # save checkpoint
            if epoch % config.write_ckpt_every_epoch == 0:
                ckpt_manager_init.save_ckpt(
                    sess,
                    trainer.pretraining_save_vars,
                    epoch,
                    score=errs_total_pretrain_test['total'])

            # reset image log
            if epoch % config.refresh_image_log_every_epoch == 0:
                writer_image_init.close()
                remove_file_end_with(config.PRETRAIN.LOG_DIR.log_image,
                                     '*.image_log')
                writer_image_init.reopen()

        if config.pretrain_only:
            return
        else:
            data_loader.reset_to_train_input(stabNet)
            data_loader_test.reset_to_train_input(stabNet)

    print(toYellow('========== TRAINING START =========='))
    global_step = 0
    for epoch in range(0, config.n_epoch):
        # update learning rate
        trainer.update_learning_rate(epoch, config.lr_init,
                                     config.lr_decay_rate, config.decay_every,
                                     sess)

        ## TRAIN
        errs_total_train = collections.OrderedDict.fromkeys(
            trainer.loss.keys(), 0.)
        errs = None
        epoch_time = time.time()
        idx = 0
        while True:
            step_time = time.time()

            feed_dict, is_end = data_loader.feed_the_network()
            if is_end: break
            feed_dict = trainer.adjust_loss_coef(feed_dict, epoch, errs)
            _, lr, errs = sess.run(
                [trainer.optim_main, trainer.learning_rate, trainer.loss],
                feed_dict)

            errs_total_train = dict_operations(errs_total_train, '+', errs)

            if global_step % config.write_ckpt_every_itr == 0:
                ckpt_manager_itr.save_ckpt(
                    sess,
                    trainer.save_vars,
                    '{:05d}_{:05d}'.format(epoch, global_step),
                    score=errs_total_train['total'] / (idx + 1))

            if global_step % config.write_log_every_itr == 0:
                summary_loss_itr, summary_image = sess.run(
                    [trainer.scalar_sum_itr, trainer.image_sum], feed_dict)
                writer_scalar_itr.add_summary(summary_loss_itr, global_step)
                writer_image.add_summary(summary_image, global_step)

            print_logs('TRAIN',
                       mode,
                       epoch,
                       step_time,
                       idx,
                       data_loader.num_itr,
                       errs=errs,
                       coefs=trainer.coef_container,
                       lr=lr)
            global_step += 1
            idx += 1

        # SAVE LOGS
        errs_total_train = dict_operations(errs_total_train, '/',
                                           data_loader.num_itr)
        summary_loss_epoch = sess.run(trainer.summary_epoch,
                                      feed_dict=dict_operations(
                                          trainer.loss_epoch_placeholder, '=',
                                          errs_total_train))
        writer_scalar_epoch.add_summary(summary_loss_epoch, epoch)

        ## TEST
        errs_total_test = collections.OrderedDict.fromkeys(
            trainer.loss_test.keys(), 0.)
        epoch_time_test = time.time()
        idx = 0
        while True:
            step_time = time.time()

            feed_dict, is_end = data_loader_test.feed_the_network()
            if is_end: break
            feed_dict = trainer.adjust_loss_coef(feed_dict, epoch, errs)
            errs = sess.run(trainer.loss_test, feed_dict)

            errs_total_test = dict_operations(errs_total_test, '+', errs)
            print_logs('TEST',
                       mode,
                       epoch,
                       step_time,
                       idx,
                       data_loader_test.num_itr,
                       errs=errs,
                       coefs=trainer.coef_container)
            idx += 1

        # SAVE LOGS
        errs_total_test = dict_operations(errs_total_test, '/',
                                          data_loader_test.num_itr)
        summary_loss_epoch_test = sess.run(trainer.summary_epoch,
                                           feed_dict=dict_operations(
                                               trainer.loss_epoch_placeholder,
                                               '=', errs_total_test))
        writer_scalar_epoch_valid.add_summary(summary_loss_epoch_test, epoch)

        ## CKPT
        if epoch % config.write_ckpt_every_epoch == 0:
            ckpt_manager.save_ckpt(sess,
                                   trainer.save_vars,
                                   epoch,
                                   score=errs_total_test['total'])

        ## RESET IMAGE SUMMARY
        if epoch % config.refresh_image_log_every_epoch == 0:
            writer_image.close()
            remove_file_end_with(config.LOG_DIR.log_image, '*.image_log')
            writer_image.reopen()

        print_logs('TRAIN SUMMARY',
                   mode,
                   epoch,
                   epoch_time,
                   errs=errs_total_train)
        print_logs('TEST SUMMARY',
                   mode,
                   epoch,
                   epoch_time_test,
                   errs=errs_total_test)
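`dict_operations` above is used three ways: '+' accumulates per-key losses, '/' averages them, and '=' pairs summary placeholders with values for a feed_dict. A minimal sketch consistent with those call sites (an assumption; the real helper is not shown):

import collections

def dict_operations(lhs, op, rhs):
    if op == '=':
        # build a feed_dict: each placeholder keyed to its value
        return {lhs[key]: rhs[key] for key in lhs.keys()}
    out = collections.OrderedDict()
    for key in lhs.keys():
        operand = rhs[key] if isinstance(rhs, dict) else rhs
        out[key] = lhs[key] + operand if op == '+' else lhs[key] / operand
    return out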
Example #6
def evaluate(config, mode):
    print(toGreen('Loading checkpoint manager...'))
    print(config.LOG_DIR.ckpt, mode)
    ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, mode, 10)

    date = datetime.datetime.now().strftime('%Y.%m.%d.(%H%M)')
    print(toYellow('======== EVALUATION START ========='))
    ##################################################
    ## DEFINE MODEL
    print(toGreen('Initializing model'))
    model = create_model(config)
    model.eval()
    model.print()

    inputs = {
        'inp': None,
        'ref': None,
        'inp_hist': None,
        'ref_hist': None,
        'seg_inp': None,
        'seg_ref': None
    }
    inputs = collections.OrderedDict(sorted(inputs.items(),
                                            key=lambda t: t[0]))

    ## INITIALIZING VARIABLE
    print(toGreen('Initializing variables'))
    result, ckpt_name = ckpt_manager.load_ckpt(
        model.get_network(),
        by_score=config.EVAL.load_ckpt_by_score,
        name=config.EVAL.ckpt_name)
    print(result)
    save_path_root = os.path.join(config.EVAL.LOG_DIR.save,
                                  config.EVAL.eval_mode, ckpt_name, date)
    exists_or_mkdir(save_path_root)
    torch.save(model.get_network().state_dict(),
               os.path.join(save_path_root, ckpt_name + '.pytorch'))

    ##################################################
    inp_folder_path_list, _, _ = load_file_list(config.EVAL.inp_path)
    ref_folder_path_list, _, _ = load_file_list(config.EVAL.ref_path)
    inp_segmap_folder_path_list, _, _ = load_file_list(
        config.EVAL.inp_segmap_path)
    ref_segmap_folder_path_list, _, _ = load_file_list(
        config.EVAL.ref_segmap_path)

    print(toGreen('Starting Color Transfer'))
    itr = 0
    for folder_idx in np.arange(len(inp_folder_path_list)):
        inp_folder_path = inp_folder_path_list[folder_idx]
        ref_video_path = ref_folder_path_list[folder_idx]
        inp_segmap_folder_path = inp_segmap_folder_path_list[folder_idx]
        ref_segmap_video_path = ref_segmap_folder_path_list[folder_idx]

        _, inp_file_path_list, _ = load_file_list(inp_folder_path)
        _, ref_file_path_list, _ = load_file_list(ref_video_path)
        _, inp_segmap_file_path_list, _ = load_file_list(
            inp_segmap_folder_path)
        _, ref_segmap_file_path_list, _ = load_file_list(ref_segmap_video_path)
        for file_idx in np.arange(len(inp_file_path_list)):
            inp_path = inp_file_path_list[file_idx]
            ref_path = ref_file_path_list[file_idx]
            inp_segmap_path = inp_segmap_file_path_list[file_idx]
            ref_segmap_path = ref_segmap_file_path_list[file_idx]

            # inputs['inp'], inputs['inp_hist'] = torch.FloatTensor(_read_frame_cv(inp_path, config).transpose(0, 3, 1, 2)).cuda()

            inputs['inp'], inputs['inp_hist'] = _read_frame_cv(
                inp_path, config)
            inputs['ref'], inputs['ref_hist'] = _read_frame_cv(
                ref_path, config)
            inputs['seg_inp'], inputs['seg_ref'] = _read_segmap(
                inputs['inp'], inp_segmap_path, ref_segmap_path, config.is_rep)

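            # replication-pad the borders before inference; the same margin is
            # cropped back off afterwards (see unpad below)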
            p = 30
            reppad = torch.nn.ReplicationPad2d(p)

            for key, val in inputs.items():
                inputs[key] = torch.FloatTensor(val.transpose(0, 3, 1, 2)).cuda()

            inputs['inp'] = reppad(inputs['inp'])
            inputs['ref'] = reppad(inputs['ref'])
            inputs['seg_inp'] = reppad(inputs['seg_inp'])
            inputs['seg_ref'] = reppad(inputs['seg_ref'])

            with torch.no_grad():
                outs = model.get_results(inputs)

            def unpad(t):
                # crop off the replication padding applied above
                return t[:, :, p:t.size(2) - p, p:t.size(3) - p]

            inputs['inp'] = unpad(inputs['inp'])
            inputs['ref'] = unpad(inputs['ref'])
            outs['result'] = unpad(outs['result'])
            outs['result_idt'] = unpad(outs['result_idt'])
            outs['seg_inp'] = unpad(outs['seg_inp'])
            outs['seg_ref'] = unpad(outs['seg_ref'])

            file_name = os.path.basename(
                inp_file_path_list[file_idx]).split('.')[0]
            save_path = save_path_root
            exists_or_mkdir(save_path)

            vutils.save_image(LAB2RGB_cv(inputs['inp'].detach().cpu(),
                                         config.type),
                              '{}/{}_1_inp.png'.format(save_path, itr),
                              nrow=3,
                              padding=0,
                              normalize=False)
            vutils.save_image(LAB2RGB_cv(inputs['ref'].detach().cpu(),
                                         config.type),
                              '{}/{}_2_ref.png'.format(save_path, itr),
                              nrow=3,
                              padding=0,
                              normalize=False)
            i = 3
            out_keys = ['result', 'result_idt', 'seg_inp', 'seg_ref']
            for key, val in outs.items():
                if key in out_keys and val is not None:
                    if config.identity is False and 'idt' in key:
                        continue
                    # segmentation maps are decoded as plain RGB and rescaled
                    img = (LAB2RGB_cv(val.detach().cpu(), 'rgb') / 8.
                           if 'seg' in key else
                           LAB2RGB_cv(val.detach().cpu(), config.type))
                    vutils.save_image(img,
                                      '{}/{}_{}_out_{}.png'.format(
                                          save_path, itr, i, key),
                                      nrow=3,
                                      padding=0,
                                      normalize=False)
                    i += 1

            #PSNR
            print('[{}][{}/{} {}/{}]'.format(ckpt_name, folder_idx + 1,
                                             len(inp_folder_path_list),
                                             file_idx + 1,
                                             len(inp_file_path_list)))
            itr += 1
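The bare `#PSNR` marker above points at a metric step that was never written in. A minimal PSNR sketch for tensors normalized to [0, 1] (an assumption; not part of the original code):

import torch

def psnr(pred, gt, eps=1e-8):
    # peak signal-to-noise ratio; inputs assumed in [0, 1]
    mse = torch.mean((pred - gt) ** 2)
    return 10.0 * torch.log10(1.0 / (mse + eps))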