Пример #1
0
def save_loss_for_matlab(trainLoss_dict, valLoss_dict, auxInfo=''):
    """Export training and validation loss dicts as MATLAB files.

    Writes two files, ``AANet_trainLoss_<auxInfo>`` and
    ``AANet_valLoss_<auxInfo>``, into ``./myDataAnalysis`` via
    ``saveDictAsMatlab``.

    trainLoss_dict: dict of training losses to export
    valLoss_dict:   dict of validation losses to export
    auxInfo:        optional suffix appended to both file names
    """
    destPath = "./myDataAnalysis"
    # Ensure the destination directory exists (the original called this
    # twice for the same path; once is sufficient).
    utils.check_path(destPath)

    destName = "AANet_trainLoss_{}".format(auxInfo)
    saveDictAsMatlab(os.path.join(destPath, destName), trainLoss_dict)

    destName = "AANet_valLoss_{}".format(auxInfo)
    saveDictAsMatlab(os.path.join(destPath, destName), valLoss_dict)
Пример #2
0
def generate_adj_data_from_grounded_concepts_ckb(cpnet_graph_path, kb,
                                                 keywords_path, output_path,
                                                 num_processes):
    """Build 2-hop adjacency data for grounded question/answer concepts.

    For every (question keywords, answer-choice keywords) pair read from the
    pickled keywords file, runs ``concepts_to_adj_matrices_2hop_all_pair`` in
    a multiprocessing pool and saves the results to ``output_path`` in python
    pickle format. Each result tuple contains:
        (1) adj      -- adjacency matrix (a (R*N, N) coo sparse matrix)
        (2) concepts -- concept ids
        (3) qmask    -- mask that specifies whether a node is a question concept
        (4) amask    -- mask that specifies whether a node is an answer concept

    cpnet_graph_path: str  -- path to the ConceptNet graph file
    kb                     -- knowledge-base object supplying vocab/relation maps
    keywords_path: str     -- pickle file of per-example keyword dicts
    output_path: str       -- destination pickle file
    num_processes: int     -- worker pool size
    """
    print(f'generating adj data for {keywords_path}...')
    check_path(output_path)

    # Module-level caches shared with the pool workers; populate them from
    # the kb object on first use.
    global concept2id, id2concept, relation2id, id2relation, cpnet_simple, cpnet
    if any(x is None
           for x in [concept2id, id2concept, relation2id, id2relation]):
        #load_resources(cpnet_vocab_path)
        concept2id = dict(kb.kb_vocab)
        id2concept = dict(kb.invert_kb_vocab)
        relation2id = dict(kb.relation2id)
        id2relation = dict(kb.id2relation)

    if cpnet is None or cpnet_simple is None:
        load_cpnet(cpnet_graph_path)

    keywords_dict = pickle.load(open(keywords_path, 'rb'))
    # Expected format of each entry:
    # {"cq_word": set(), "answerA": set(), "answerB": set(), "answerC": set()}

    # One (question ids, answer ids) pair per answer choice of each example.
    qa_data = []
    for value in keywords_dict:
        for key, words in value.items():
            if key.startswith("answer"):
                mcp = {"qc": value['cq_word'], "ac": words}
                q_ids = set(concept2id[c] for c in mcp["qc"])
                a_ids = set(concept2id[c] for c in mcp["ac"])
                q_ids = q_ids - a_ids  # question concepts exclude answer concepts
                qa_data.append((q_ids, a_ids))

    with Pool(num_processes) as p:
        res = list(
            tqdm(p.imap(concepts_to_adj_matrices_2hop_all_pair, qa_data),
                 total=len(qa_data)))

    # res is a list of tuples, each tuple consists of four elements (adj, concepts, qmask, amask)
    with open(output_path, 'wb') as fout:
        pickle.dump(res, fout)

    print(f'adj data saved to {output_path}')
    print()
Пример #3
0
def csqa_keywords_load(kb, input_file, output_file):
    '''Extract ConceptNet keywords for every CommonsenseQA example.

    Reads a jsonl file (one CSQA example per line) and, for each example,
    collects keywords from the question stem and from each of the five
    answer choices, keeping only words present in ``kb.only_word_dict``.
    When an example yields no question keywords and an answer also yields
    none, falls back to the annotated ``question_concept``; if that is not
    in the vocabulary either, the process aborts.

    #Output format (pickled to output_file and returned):
        {
            id: {"cq_word":set(), "answerA":set(), ..., "answerE":set()},
            ...
        }
    Each set contains ('conceptnet', word) tuples.
    '''
    check_path(output_file)

    # Read all examples up front; the file is closed before processing.
    with open(input_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    all_keywords_dict = [dict() for _ in range(len(lines))]
    for idx, line in enumerate(lines):
        json_dic = json.loads(line)

        choices = json_dic["question"]["choices"]
        answers_dict = {"answerA": choices[0]["text"],
                        "answerB": choices[1]["text"],
                        "answerC": choices[2]["text"],
                        "answerD": choices[3]["text"],
                        "answerE": choices[4]["text"],
                        }

        entry = {"cq_word": set(), "answerA": set(), "answerB": set(),
                 "answerC": set(), "answerD": set(), "answerE": set()}
        all_keywords_dict[idx] = entry

        entry["cq_word"] = set(
            ('conceptnet', word)
            for word in kb.get_keywords_from_text(json_dic["question"]["stem"])
            if word in kb.only_word_dict)
        for key, value in answers_dict.items():
            entry[key] = set(
                ('conceptnet', word)
                for word in kb.get_keywords_from_text(value)
                if word in kb.only_word_dict)
            # Fallback: neither the question nor this answer produced any
            # keyword, so try the annotated question concept instead.
            if len(entry["cq_word"]) == 0 and len(entry[key]) == 0:
                question_concept = json_dic["question"]["question_concept"].replace(" ", "_")
                entry["cq_word"] = ({('conceptnet', question_concept)}
                                    if question_concept in kb.only_word_dict
                                    else set())
                if len(entry["cq_word"]) == 0:
                    print("No keywords!", idx, all_keywords_dict)
                    exit()

    # Use a context manager so the handle is closed and flushed (the
    # original passed an anonymous open() to pickle.dump and leaked it).
    with open(output_file, 'wb') as fout:
        pickle.dump(all_keywords_dict, fout)

    return all_keywords_dict
Пример #4
0
parser.add_argument("--encode_layer", type=int, default=-1)

parser.add_argument("--seed",
                    type=int,
                    default=42,
                    help="random seed for initialization")
parser.add_argument("--no_cuda",
                    action="store_true",
                    help="Avoid using CUDA when available")
args = parser.parse_args()

# Resolve the compute device and visible GPU count from the CLI flags.
args.device = torch.device(
    "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()

# Make sure the output path's directory exists before generation starts.
check_path(args.output_gen_rel_jsonl)

set_seed(args)

# Initialize the model and tokenizer
model_class, tokenizer_class = GPT2LMHeadModel, GPT2Tokenizer

tokenizer = tokenizer_class.from_pretrained(args.generator_ckpt_folder)
# output_hidden_states=True: hidden states of every layer are needed so
# --encode_layer can select which one to use as features.
model = model_class.from_pretrained(args.generator_ckpt_folder,
                                    output_hidden_states=True)
model.to(args.device)
model.eval()
feature_size = model.config.hidden_size
# Cap the search length to what the model's positional embeddings support.
args.search_max_len = adjust_length_to_model(
    args.search_max_len,
    max_sequence_length=model.config.max_position_embeddings)
Пример #5
0
def main():
    """Run stereo disparity inference over the test set and save predictions.

    Builds the test DataLoader and AANet model from module-level ``args``,
    runs a forward pass per stereo pair (padding inputs up to
    ``args.img_height`` x ``args.img_width`` when needed, cropping the
    prediction back afterwards), and writes each disparity map to
    ``args.output_dir`` as pfm/npy/png per ``args.save_type``. With
    ``--count_time`` only timing is measured and nothing is saved.
    """
    # For reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    torch.backends.cudnn.benchmark = True

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Test loader
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)])
    test_data = dataloader.StereoDataset(data_dir=args.data_dir,
                                         dataset_name=args.dataset_name,
                                         mode=args.mode,
                                         save_filename=True,
                                         transform=test_transform)
    test_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, shuffle=False,
                             num_workers=args.num_workers, pin_memory=True, drop_last=False)

    aanet = nets.AANet(args.max_disp,
                       num_downsample=args.num_downsample,
                       feature_type=args.feature_type,
                       no_feature_mdconv=args.no_feature_mdconv,
                       feature_pyramid=args.feature_pyramid,
                       feature_pyramid_network=args.feature_pyramid_network,
                       feature_similarity=args.feature_similarity,
                       aggregation_type=args.aggregation_type,
                       num_scales=args.num_scales,
                       num_fusions=args.num_fusions,
                       num_stage_blocks=args.num_stage_blocks,
                       num_deform_blocks=args.num_deform_blocks,
                       no_intermediate_supervision=args.no_intermediate_supervision,
                       refinement_type=args.refinement_type,
                       mdconv_dilation=args.mdconv_dilation,
                       deformable_groups=args.deformable_groups).to(device)

    # print(aanet)

    if os.path.exists(args.pretrained_aanet):
        print('=> Loading pretrained AANet:', args.pretrained_aanet)
        utils.load_pretrained_net(aanet, args.pretrained_aanet, no_strict=True)
    else:
        print('=> Using random initialization')

    # Save parameters
    num_params = utils.count_parameters(aanet)
    print('=> Number of trainable parameters: %d' % num_params)

    if torch.cuda.device_count() > 1:
        print('=> Use %d GPUs' % torch.cuda.device_count())
        aanet = torch.nn.DataParallel(aanet)

    # Inference
    aanet.eval()

    inference_time = 0
    num_imgs = 0

    num_samples = len(test_loader)
    print('=> %d samples found in the test set' % num_samples)

    for i, sample in enumerate(test_loader):
        if args.count_time and i == args.num_images:  # testing time only
            break

        if i % 100 == 0:
            print('=> Inferencing %d/%d' % (i, num_samples))

        left = sample['left'].to(device)  # [B, 3, H, W]
        right = sample['right'].to(device)

        # Pad inputs smaller than the expected inference resolution.
        ori_height, ori_width = left.size()[2:]
        if ori_height < args.img_height or ori_width < args.img_width:
            top_pad = args.img_height - ori_height
            right_pad = args.img_width - ori_width

            # Pad size: (left_pad, right_pad, top_pad, bottom_pad)
            left = F.pad(left, (0, right_pad, top_pad, 0))
            right = F.pad(right, (0, right_pad, top_pad, 0))

        # Warmup so GPU initialization is excluded from timing.
        if i == 0 and args.count_time:
            with torch.no_grad():
                for _ in range(10):
                    aanet(left, right)

        num_imgs += left.size(0)

        with torch.no_grad():
            time_start = time.perf_counter()
            pred_disp = aanet(left, right)[-1]  # [B, H, W]
            inference_time += time.perf_counter() - time_start

        if pred_disp.size(-1) < left.size(-1):
            pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
            # Upsample to input resolution and rescale disparity values by the
            # width ratio. NOTE: the original also passed
            # recompute_scale_factor=True, which is invalid together with an
            # explicit output size in current PyTorch (raises ValueError) and
            # is inconsistent with the train/validate code paths; dropped.
            pred_disp = F.interpolate(pred_disp, (left.size(-2), left.size(-1)),
                                      mode='bilinear', align_corners=True) * (left.size(-1) / pred_disp.size(-1))
            pred_disp = pred_disp.squeeze(1)  # [B, H, W]

        # Crop the padding back off the prediction.
        if ori_height < args.img_height or ori_width < args.img_width:
            if right_pad != 0:
                pred_disp = pred_disp[:, top_pad:, :-right_pad]
            else:
                pred_disp = pred_disp[:, top_pad:]

        for b in range(pred_disp.size(0)):
            disp = pred_disp[b].detach().cpu().numpy()  # [H, W]
            save_name = sample['left_name'][b]
            save_name = os.path.join(args.output_dir, save_name)
            utils.check_path(os.path.dirname(save_name))
            if not args.count_time:
                if args.save_type == 'pfm':
                    if args.visualize:
                        skimage.io.imsave(save_name, (disp * 256.).astype(np.uint16))

                    save_name = save_name[:-3] + 'pfm'
                    write_pfm(save_name, disp)
                elif args.save_type == 'npy':
                    save_name = save_name[:-3] + 'npy'
                    np.save(save_name, disp)
                else:
                    skimage.io.imsave(save_name, (disp * 256.).astype(np.uint16))

    # Guard against division by zero when the loader yielded no images.
    if num_imgs > 0:
        print('=> Mean inference time for %d images: %.3fs' % (num_imgs, inference_time / num_imgs))
    else:
        print('=> No images were processed')
Пример #6
0
parser.add_argument('--pretrained_aanet', type=str, default=None, help='Pretrained network')

parser.add_argument('--save_type', choices=['pfm', 'png', 'npy'], default='png', help='Save file type')
parser.add_argument('--visualize', action='store_true', help='Visualize disparity map')

# Logging / timing options
parser.add_argument('--count_time', action='store_true', help='Inference on a subset for time counting only')
parser.add_argument('--num_images', type=int, default=100, help='Number of images for inference')

args = parser.parse_args()

# Name the output directory "<checkpoint-parent-dir>-<checkpoint-stem>".
ckpt_path = args.pretrained_aanet
ckpt_stem = os.path.basename(ckpt_path)[:-4]  # strip the 4-char ".pth" suffix
ckpt_dir_name = os.path.basename(os.path.dirname(ckpt_path))
args.output_dir = os.path.join(args.output_dir, '-'.join((ckpt_dir_name, ckpt_stem)))

utils.check_path(args.output_dir)
utils.save_command(args.output_dir)

def main():
    # For reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    torch.backends.cudnn.benchmark = True

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Test loader
    test_transform = transforms.Compose([
Пример #7
0
# Logging options
parser.add_argument('--print_freq', type=int, default=1, help='Print frequency to screen')
parser.add_argument('--summary_freq', type=int, default=100, help='Summary frequency to tensorboard')
parser.add_argument('--no_build_summary', action='store_true', help='Dont save sammary when training to save space')
parser.add_argument('--save_ckpt_freq', type=int, default=10, help='Save checkpoint frequency')

parser.add_argument('--evaluate_only', action='store_true', help='Evaluate pretrained models')
parser.add_argument('--no_validate', action='store_true', help='No validation')
parser.add_argument('--strict', action='store_true', help='Strict mode when loading checkpoints')
parser.add_argument('--val_metric', default='epe', help='Validation metric to select best model')

args = parser.parse_args()
logger = utils.get_logger()

utils.check_path(args.checkpoint_dir)
utils.save_args(args)

# Record the launch command next to the checkpoints.
if args.mode == 'test':
    filename = 'command_test.txt'
else:
    filename = 'command_train.txt'
utils.save_command(args.checkpoint_dir, filename)


def main():
    # For reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    torch.backends.cudnn.benchmark = True

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Пример #8
0
    def train(self, train_loader):
        """Train the AANet model for one epoch over ``train_loader``.

        Computes a weighted multi-scale smooth-L1 disparity loss over the
        prediction pyramid, optionally adds a pseudo-ground-truth loss on
        pixels that lack real ground truth, logs scalars and images to
        tensorboard, and (when validation is disabled) saves checkpoints at
        the end of the epoch.
        """
        args = self.args
        logger = self.logger

        steps_per_epoch = len(train_loader)
        device = self.device

        self.aanet.train()

        if args.freeze_bn:

            # Keep BatchNorm layers in eval mode so running stats stay frozen.
            def set_bn_eval(m):
                classname = m.__class__.__name__
                if classname.find('BatchNorm') != -1:
                    m.eval()

            self.aanet.apply(set_bn_eval)

        # Learning rate summary (param group 0: base, group 1: offsets)
        base_lr = self.optimizer.param_groups[0]['lr']
        offset_lr = self.optimizer.param_groups[1]['lr']
        self.train_writer.add_scalar('base_lr', base_lr, self.epoch + 1)
        self.train_writer.add_scalar('offset_lr', offset_lr, self.epoch + 1)

        last_print_time = time.time()

        for i, sample in enumerate(train_loader):
            left = sample['left'].to(device)  # [B, 3, H, W]
            right = sample['right'].to(device)
            gt_disp = sample['disp'].to(device)  # [B, H, W]

            # Supervise only pixels with a valid ground-truth disparity.
            mask = (gt_disp > 0) & (gt_disp < args.max_disp)

            if args.load_pseudo_gt:
                pseudo_gt_disp = sample['pseudo_disp'].to(device)
                pseudo_mask = (pseudo_gt_disp > 0) & (
                    pseudo_gt_disp < args.max_disp) & (~mask)  # inverse mask

            if not mask.any():
                continue

            pred_disp_pyramid = self.aanet(
                left, right)  # list of H/12, H/6, H/3, H/2, H

            if args.highest_loss_only:
                pred_disp_pyramid = [
                    pred_disp_pyramid[-1]
                ]  # only the last highest resolution output

            disp_loss = 0
            pseudo_disp_loss = 0
            pyramid_loss = []
            pseudo_pyramid_loss = []

            # Loss weights, lowest resolution first
            if len(pred_disp_pyramid) == 5:
                pyramid_weight = [1 / 3, 2 / 3, 1.0, 1.0, 1.0]  # AANet
            elif len(pred_disp_pyramid) == 4:
                pyramid_weight = [1 / 3, 2 / 3, 1.0, 1.0]  # AANet+
            elif len(pred_disp_pyramid) == 3:
                pyramid_weight = [1.0, 1.0, 1.0]  # 1 scale only
            elif len(pred_disp_pyramid) == 1:
                pyramid_weight = [1.0]  # highest loss only
            else:
                raise NotImplementedError

            assert len(pyramid_weight) == len(pred_disp_pyramid)
            for k in range(len(pred_disp_pyramid)):
                pred_disp = pred_disp_pyramid[k]
                weight = pyramid_weight[k]

                if pred_disp.size(-1) != gt_disp.size(-1):
                    pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                    # Upsample to gt resolution; disparity values are rescaled
                    # by the width ratio to stay metrically consistent.
                    pred_disp = F.interpolate(
                        pred_disp,
                        size=(gt_disp.size(-2), gt_disp.size(-1)),
                        mode='bilinear') * (gt_disp.size(-1) /
                                            pred_disp.size(-1))
                    pred_disp = pred_disp.squeeze(1)  # [B, H, W]

                curr_loss = F.smooth_l1_loss(pred_disp[mask],
                                             gt_disp[mask],
                                             reduction='mean')
                disp_loss += weight * curr_loss
                pyramid_loss.append(curr_loss)

                # Pseudo gt loss (only on pixels not covered by real gt)
                if args.load_pseudo_gt:
                    pseudo_curr_loss = F.smooth_l1_loss(
                        pred_disp[pseudo_mask],
                        pseudo_gt_disp[pseudo_mask],
                        reduction='mean')
                    pseudo_disp_loss += weight * pseudo_curr_loss

                    pseudo_pyramid_loss.append(pseudo_curr_loss)

            total_loss = disp_loss + pseudo_disp_loss

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            self.num_iter += 1

            if self.num_iter % args.print_freq == 0:
                this_cycle = time.time() - last_print_time
                last_print_time += this_cycle

                logger.info(
                    'Epoch: [%3d/%3d] [%5d/%5d] time: %4.2fs disp_loss: %.3f' %
                    (self.epoch + 1, args.max_epoch, i + 1, steps_per_epoch,
                     this_cycle, disp_loss.item()))

            if self.num_iter % args.summary_freq == 0:
                img_summary = dict()
                img_summary['left'] = left
                img_summary['right'] = right
                img_summary['gt_disp'] = gt_disp

                if args.load_pseudo_gt:
                    img_summary['pseudo_gt_disp'] = pseudo_gt_disp

                # Save pyramid disparity prediction
                for s in range(len(pred_disp_pyramid)):
                    # Scale from low to high, reverse
                    save_name = 'pred_disp' + str(
                        len(pred_disp_pyramid) - s - 1)
                    save_value = pred_disp_pyramid[s]
                    img_summary[save_name] = save_value

                # Highest-resolution prediction for the error image and metrics
                pred_disp = pred_disp_pyramid[-1]

                if pred_disp.size(-1) != gt_disp.size(-1):
                    pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                    pred_disp = F.interpolate(
                        pred_disp,
                        size=(gt_disp.size(-2), gt_disp.size(-1)),
                        mode='bilinear') * (gt_disp.size(-1) /
                                            pred_disp.size(-1))
                    pred_disp = pred_disp.squeeze(1)  # [B, H, W]
                img_summary['disp_error'] = disp_error_img(pred_disp, gt_disp)

                save_images(self.train_writer, 'train', img_summary,
                            self.num_iter)

                epe = F.l1_loss(gt_disp[mask],
                                pred_disp[mask],
                                reduction='mean')

                self.train_writer.add_scalar('train/epe', epe.item(),
                                             self.num_iter)
                self.train_writer.add_scalar('train/disp_loss',
                                             disp_loss.item(), self.num_iter)
                self.train_writer.add_scalar('train/total_loss',
                                             total_loss.item(), self.num_iter)

                # Save loss of different scale
                for s in range(len(pyramid_loss)):
                    save_name = 'train/loss' + str(len(pyramid_loss) - s - 1)
                    save_value = pyramid_loss[s]
                    self.train_writer.add_scalar(save_name, save_value,
                                                 self.num_iter)

                d1 = d1_metric(pred_disp, gt_disp, mask)
                self.train_writer.add_scalar('train/d1', d1.item(),
                                             self.num_iter)
                thres1 = thres_metric(pred_disp, gt_disp, mask, 1.0)
                thres2 = thres_metric(pred_disp, gt_disp, mask, 2.0)
                thres3 = thres_metric(pred_disp, gt_disp, mask, 3.0)
                self.train_writer.add_scalar('train/thres1', thres1.item(),
                                             self.num_iter)
                self.train_writer.add_scalar('train/thres2', thres2.item(),
                                             self.num_iter)
                self.train_writer.add_scalar('train/thres3', thres3.item(),
                                             self.num_iter)

        self.epoch += 1

        # Always save the latest model for resuming training
        if args.no_validate:
            utils.save_checkpoint(args.checkpoint_dir,
                                  self.optimizer,
                                  self.aanet,
                                  epoch=self.epoch,
                                  num_iter=self.num_iter,
                                  epe=-1,  # no validation run -> epe unknown
                                  best_epe=self.best_epe,
                                  best_epoch=self.best_epoch,
                                  filename='aanet_latest.pth')

            # Save checkpoint of specific epoch
            if self.epoch % args.save_ckpt_freq == 0:
                model_dir = os.path.join(args.checkpoint_dir, 'models')
                utils.check_path(model_dir)
                utils.save_checkpoint(model_dir,
                                      self.optimizer,
                                      self.aanet,
                                      epoch=self.epoch,
                                      num_iter=self.num_iter,
                                      epe=-1,
                                      best_epe=self.best_epe,
                                      best_epoch=self.best_epoch,
                                      save_optimizer=False)
Пример #9
0
    def validate(self, val_loader):
        """Evaluate the model on ``val_loader`` and track the best checkpoint.

        Computes mean EPE, D1 and threshold metrics over all samples with
        valid ground truth, appends the results to ``val_results.txt``, and
        (unless running in evaluate-only mode) logs to tensorboard and saves
        the best/latest checkpoints according to ``args.val_metric``.
        """
        args = self.args
        logger = self.logger
        logger.info('=> Start validation...')

        # In evaluate-only mode, load weights from the given checkpoint or
        # fall back to the best (or, failing that, latest) saved model.
        if args.evaluate_only is True:
            if args.pretrained_aanet is not None:
                pretrained_aanet = args.pretrained_aanet
            else:
                model_name = 'aanet_best.pth'
                pretrained_aanet = os.path.join(args.checkpoint_dir,
                                                model_name)
                if not os.path.exists(
                        pretrained_aanet):  # KITTI without validation
                    pretrained_aanet = pretrained_aanet.replace(
                        model_name, 'aanet_latest.pth')

            logger.info('=> loading pretrained aanet: %s' % pretrained_aanet)
            utils.load_pretrained_net(self.aanet,
                                      pretrained_aanet,
                                      no_strict=True)

        self.aanet.eval()

        num_samples = len(val_loader)
        logger.info('=> %d samples found in the validation set' % num_samples)

        # Running sums of per-batch metrics; averaged over valid_samples below.
        val_epe = 0
        val_d1 = 0
        val_thres1 = 0
        val_thres2 = 0
        val_thres3 = 0

        val_count = 0

        val_file = os.path.join(args.checkpoint_dir, 'val_results.txt')

        num_imgs = 0
        valid_samples = 0

        for i, sample in enumerate(val_loader):
            if i % 100 == 0:
                logger.info('=> Validating %d/%d' % (i, num_samples))

            left = sample['left'].to(self.device)  # [B, 3, H, W]
            right = sample['right'].to(self.device)
            gt_disp = sample['disp'].to(self.device)  # [B, H, W]
            mask = (gt_disp > 0) & (gt_disp < args.max_disp)

            # Skip batches with no valid ground-truth pixels.
            if not mask.any():
                continue

            valid_samples += 1

            num_imgs += gt_disp.size(0)

            with torch.no_grad():
                pred_disp = self.aanet(left, right)[-1]  # [B, H, W]

            if pred_disp.size(-1) < gt_disp.size(-1):
                pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                # Upsample to gt resolution; rescale disparity by width ratio.
                pred_disp = F.interpolate(
                    pred_disp, (gt_disp.size(-2), gt_disp.size(-1)),
                    mode='bilinear') * (gt_disp.size(-1) / pred_disp.size(-1))
                pred_disp = pred_disp.squeeze(1)  # [B, H, W]

            epe = F.l1_loss(gt_disp[mask], pred_disp[mask], reduction='mean')
            d1 = d1_metric(pred_disp, gt_disp, mask)
            thres1 = thres_metric(pred_disp, gt_disp, mask, 1.0)
            thres2 = thres_metric(pred_disp, gt_disp, mask, 2.0)
            thres3 = thres_metric(pred_disp, gt_disp, mask, 3.0)

            val_epe += epe.item()
            val_d1 += d1.item()
            val_thres1 += thres1.item()
            val_thres2 += thres2.item()
            val_thres3 += thres3.item()

            # Save 3 images for visualization
            if not args.evaluate_only:
                if i in [
                        num_samples // 4, num_samples // 2,
                        num_samples // 4 * 3
                ]:
                    img_summary = dict()
                    img_summary['disp_error'] = disp_error_img(
                        pred_disp, gt_disp)
                    img_summary['left'] = left
                    img_summary['right'] = right
                    img_summary['gt_disp'] = gt_disp
                    img_summary['pred_disp'] = pred_disp
                    save_images(self.train_writer, 'val' + str(val_count),
                                img_summary, self.epoch)
                    val_count += 1

        logger.info('=> Validation done!')

        mean_epe = val_epe / valid_samples
        mean_d1 = val_d1 / valid_samples
        mean_thres1 = val_thres1 / valid_samples
        mean_thres2 = val_thres2 / valid_samples
        mean_thres3 = val_thres3 / valid_samples

        # Save validation results
        with open(val_file, 'a') as f:
            f.write('epoch: %03d\t' % self.epoch)
            f.write('epe: %.3f\t' % mean_epe)
            f.write('d1: %.4f\t' % mean_d1)
            f.write('thres1: %.4f\t' % mean_thres1)
            f.write('thres2: %.4f\t' % mean_thres2)
            f.write('thres3: %.4f\n' % mean_thres3)

        logger.info('=> Mean validation epe of epoch %d: %.3f' %
                    (self.epoch, mean_epe))

        if not args.evaluate_only:
            self.train_writer.add_scalar('val/epe', mean_epe, self.epoch)
            self.train_writer.add_scalar('val/d1', mean_d1, self.epoch)
            self.train_writer.add_scalar('val/thres1', mean_thres1, self.epoch)
            self.train_writer.add_scalar('val/thres2', mean_thres2, self.epoch)
            self.train_writer.add_scalar('val/thres3', mean_thres3, self.epoch)

        # Track the best model under the selected metric (lower is better).
        if not args.evaluate_only:
            if args.val_metric == 'd1':
                if mean_d1 < self.best_epe:
                    # Actually best_epe here is d1
                    self.best_epe = mean_d1
                    self.best_epoch = self.epoch

                    utils.save_checkpoint(args.checkpoint_dir,
                                          self.optimizer,
                                          self.aanet,
                                          epoch=self.epoch,
                                          num_iter=self.num_iter,
                                          epe=mean_d1,
                                          best_epe=self.best_epe,
                                          best_epoch=self.best_epoch,
                                          filename='aanet_best.pth')
            elif args.val_metric == 'epe':
                if mean_epe < self.best_epe:
                    self.best_epe = mean_epe
                    self.best_epoch = self.epoch

                    utils.save_checkpoint(args.checkpoint_dir,
                                          self.optimizer,
                                          self.aanet,
                                          epoch=self.epoch,
                                          num_iter=self.num_iter,
                                          epe=mean_epe,
                                          best_epe=self.best_epe,
                                          best_epoch=self.best_epoch,
                                          filename='aanet_best.pth')
            else:
                raise NotImplementedError

        if self.epoch == args.max_epoch:
            # Save best validation results
            with open(val_file, 'a') as f:
                f.write('\nbest epoch: %03d \t best %s: %.3f\n\n' %
                        (self.best_epoch, args.val_metric, self.best_epe))

            logger.info('=> best epoch: %03d \t best %s: %.3f\n' %
                        (self.best_epoch, args.val_metric, self.best_epe))

        # Always save the latest model for resuming training
        if not args.evaluate_only:
            utils.save_checkpoint(args.checkpoint_dir,
                                  self.optimizer,
                                  self.aanet,
                                  epoch=self.epoch,
                                  num_iter=self.num_iter,
                                  epe=mean_epe,
                                  best_epe=self.best_epe,
                                  best_epoch=self.best_epoch,
                                  filename='aanet_latest.pth')

            # Save checkpoint of specific epochs
            if self.epoch % args.save_ckpt_freq == 0:
                model_dir = os.path.join(args.checkpoint_dir, 'models')
                utils.check_path(model_dir)
                utils.save_checkpoint(model_dir,
                                      self.optimizer,
                                      self.aanet,
                                      epoch=self.epoch,
                                      num_iter=self.num_iter,
                                      epe=mean_epe,
                                      best_epe=self.best_epe,
                                      best_epoch=self.best_epoch,
                                      save_optimizer=False)
Пример #10
0
File: hgn.py Project: INK-USC/HGN
def main():
    """Parse CLI arguments, configure seeding and logging, then dispatch
    to train / eval / pred according to ``--mode``.

    After a training run, the log file (and, when ``--save_model`` is set,
    the saved model) is renamed to embed the achieved dev/test accuracies.
    """
    parser = get_parser()
    args, _ = parser.parse_known_args()
    # Fix: 'eval' is handled in the dispatch below but was missing from
    # choices, which made that branch unreachable from the command line.
    parser.add_argument('--mode', default='train', choices=['train', 'eval', 'pred'], help='run training or evaluation')
    parser.add_argument('--save_dir', required=True, help='model output directory')
    parser.add_argument('--save_file_name', default='')
    parser.add_argument('--save_model', default=True, type=bool_flag)

    # statements
    parser.add_argument('--train_jsonl', required=True)
    parser.add_argument('--dev_jsonl', required=True)
    parser.add_argument('--test_jsonl')

    # data
    parser.add_argument('--num_choice', type=int, required=True, help='how many choices for each question')

    parser.add_argument('--train_adj_pk', required=True)
    parser.add_argument('--train_gen_pt', required=True)

    parser.add_argument('--dev_adj_pk', required=True)
    parser.add_argument('--dev_gen_pt', required=True)

    parser.add_argument('--test_adj_pk')
    parser.add_argument('--test_gen_pt')

    # pred mode
    parser.add_argument('--test_path_base')
    parser.add_argument('--test_model_path')
    parser.add_argument('--output_pred_path')
    parser.add_argument('--output_graph', default=False, type=bool_flag)

    # model architecture
    parser.add_argument('--ablation', default=[], nargs='+', choices=['GAT', 'no_edge_weight', 'extraction_only', 'unnormalized_edge_weight', 'wo_statement_vec'])
    # no_edge_weight = no learnable edge weight in message passing (all weights = 1) + no sparsity loss
    parser.add_argument('--att_head_num', default=2, type=int, help='number of attention heads')
    parser.add_argument('--mlp_dim', default=128, type=int, help='number of MLP hidden units')
    parser.add_argument('--fc_dim', default=128, type=int, help='number of FC hidden units')
    parser.add_argument('--fc_layer_num', default=0, type=int, help='number of FC layers')
    parser.add_argument('--freeze_ent_emb', default=True, type=bool_flag, nargs='?', const=True, help='freeze entity embedding layer')
    parser.add_argument('--emb_scale', default=1.0, type=float, help='scale pretrained embeddings')
    # Fix: help text was copy-pasted from --emb_scale.
    parser.add_argument('--num_gnn_layers', default=1, type=int, help='number of GNN message-passing layers')
    # regularization
    parser.add_argument('--dropoutm', type=float, default=0.3, help='dropout for mlp hidden units (0 = no dropout)')

    # optimization
    parser.add_argument('-dlr', '--decoder_lr', default=3e-4, type=float, help='learning rate')
    parser.add_argument('-mbs', '--mini_batch_size', default=1, type=int)  # batch size should be divisible by mini batch size in current implementation.
    parser.add_argument('-ebs', '--eval_batch_size', default=-1, type=int)
    parser.add_argument('--unfreeze_epoch', default=0, type=int)
    parser.add_argument('--refreeze_epoch', default=10000, type=int)
    parser.add_argument('--eval_interval', default=0, type=int, help='steps_per_eval (0 = eval after each epoch)')

    # specific to Hybrid GN
    parser.add_argument('--alpha', default=0, type=float, help='weight for binary loss')
    parser.add_argument('--edge_weight_dropout', default=0.2, type=float)

    # CODAH
    parser.add_argument('--warmup_ratio', default=None, type=float)
    parser.add_argument('--use_last_epoch', default=False, type=bool_flag)
    parser.add_argument('--fold', default=None, type=str)

    args = parser.parse_args()
    if args.test_path_base is not None:
        args.test_model_path = args.test_path_base + args.test_model_path
        args.output_pred_path = args.test_path_base + args.output_pred_path
    # The test data loader requires eval batches to match mini-batches, so the
    # -ebs value is always overridden (was a dead `if True:` conditional).
    args.eval_batch_size = args.mini_batch_size
    if args.mini_batch_size > args.batch_size:
        args.mini_batch_size = args.batch_size
    if args.batch_size % args.mini_batch_size != 0:
        raise ValueError('batch size should be divisible by mini batch size')
    if args.fold is not None:
        # Substitute the fold id into every path template.  The optional
        # test paths may be None (e.g. --test_jsonl omitted), so guard
        # before calling str.replace — the original raised AttributeError.
        for attr in ('train_jsonl', 'dev_jsonl', 'test_jsonl',
                     'train_adj_pk', 'dev_adj_pk', 'test_adj_pk',
                     'train_gen_pt', 'dev_gen_pt', 'test_gen_pt',
                     'save_dir'):
            path_template = getattr(args, attr)
            if path_template is not None:
                setattr(args, attr, path_template.replace('{fold}', args.fold))

    # Seed every RNG source for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # Unique per-run name (timestamp + user suffix) shared by log and model files.
    unique_str = datetime.now().strftime("%m%d_%H%M%S.%f") + args.save_file_name
    log_name = unique_str + '.log'
    log_path = os.path.join(args.save_dir, log_name)
    check_path(log_path)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(message)s",
        handlers=[
            logging.FileHandler(log_path),
            logging.StreamHandler()
        ]
    )
    args.save_file_name = unique_str + '.pt'
    if args.mode == 'train':
        dev_acc, test_acc, best_test_acc = train(args)
        # Rename log (and model) to embed the achieved accuracies.
        new_file_str = f'dlr{args.decoder_lr}_{dev_acc * 100:.2f}_{test_acc * 100:.2f}_s{args.seed}_{best_test_acc * 100:.2f}_{unique_str}'
        os.rename(log_path, os.path.join(args.save_dir, new_file_str + '.log'))
        if args.save_model:
            os.rename(os.path.join(args.save_dir, unique_str + '.pt'), os.path.join(args.save_dir, new_file_str + '.pt'))
    elif args.mode == 'eval':
        # NOTE(review): presumably a project-level eval() shadows the builtin
        # here — confirm the import; otherwise this calls builtins.eval.
        eval(args)
    elif args.mode == 'pred':
        pred(args)
    else:
        raise ValueError('Invalid mode')
Пример #11
0
Файл: hgn.py Проект: INK-USC/HGN
def train(args):
    """Train LMGraphNet end-to-end and return accuracies.

    Returns a 3-tuple ``(dev_acc, test_acc, best_test_acc)``: the
    last-epoch dev/test accuracies when ``args.use_last_epoch`` is set,
    otherwise the best dev accuracy and its corresponding test accuracy.
    Early stopping triggers after ``args.patience`` evaluations without a
    dev improvement (only once the encoder has been unfrozen).
    """
    logging.info(f'{socket.gethostname()}: {os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else "unknown"}')
    logging.info('python ' + ' '.join(sys.argv))
    logging.info(args)

    model_path = os.path.join(args.save_dir, args.save_file_name)
    check_path(model_path)

    ###################################################################################################
    #   Load data                                                                                     #
    ###################################################################################################

    # Concept (entity) embeddings can come from several files; concatenate feature-wise.
    cp_emb = [np.load(path) for path in args.ent_emb_paths]
    cp_emb = torch.tensor(np.concatenate(cp_emb, 1))
    concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)

    # Append negated relation embeddings — presumably one embedding per
    # reverse-direction relation; TODO confirm against graph construction.
    rel_emb = np.load(args.rel_emb_path)
    rel_emb = np.concatenate((rel_emb, -rel_emb), 0)
    rel_emb = torch.tensor(rel_emb)
    relation_num, relation_dim = rel_emb.size(0), rel_emb.size(1)
    logging.info('| num_concepts: {} | num_relations: {} |'.format(concept_num, relation_num))

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    lm_data_loader = LMDataLoader(args.train_jsonl, args.dev_jsonl, args.test_jsonl,
                                  batch_size=args.mini_batch_size, eval_batch_size=args.eval_batch_size, device=device,
                                  model_name=args.encoder, max_seq_length=args.max_seq_len, is_inhouse=args.inhouse,
                                  inhouse_train_qids_path=args.inhouse_train_qids, subset_qids_path=args.subset_train_qids,
                                  format=args.format)
    logging.info(f'| # train questions: {lm_data_loader.train_size()} | # dev questions: {lm_data_loader.dev_size()} | # test questions: {lm_data_loader.test_size()} |')

    ###################################################################################################
    #   Build model                                                                                   #
    ###################################################################################################
    graph_data_loader = GraphDataLoader(args.train_adj_pk, args.train_gen_pt, args.dev_adj_pk, args.dev_gen_pt,
                                        args.test_adj_pk, args.test_gen_pt,
                                        args.mini_batch_size, args.eval_batch_size, args.num_choice, args.ablation)
    # stats_only: compute average node/edge counts without keeping the loader.
    train_avg_node_num, train_avg_edge_num = graph_data_loader.get_pyg_loader(lm_data_loader.get_train_indexes(), stats_only=True)

    dev_lm_data_loader = lm_data_loader.dev()
    dev_graph_loader, dev_avg_node_num, dev_avg_edge_num = graph_data_loader.dev_graph_data()
    assert len(dev_graph_loader) == len(dev_lm_data_loader)

    # In-house split: the "test" set is carved out of the training data.
    if args.inhouse:
        test_index = lm_data_loader.get_test_indexes()
        test_graph_loader, test_avg_node_num, test_avg_edge_num = graph_data_loader.get_pyg_loader(test_index)
    else:
        test_index = None
        test_graph_loader, test_avg_node_num, test_avg_edge_num = graph_data_loader.test_graph_data()
    test_lm_data_loader = lm_data_loader.test(test_index)
    assert len(test_graph_loader) == len(test_lm_data_loader)

    logging.info(f'| train | avg node num: {train_avg_node_num:.2f} | avg edge num: {train_avg_edge_num:.2f} |')
    logging.info(f'| dev   | avg node num: {dev_avg_node_num:.2f} | avg edge num: {dev_avg_edge_num:.2f} |')
    logging.info(f'| test  | avg node num: {test_avg_node_num:.2f} | avg edge num: {test_avg_edge_num:.2f} |')

    # NOTE(review): concept_dim is passed relation_dim (not concept_dim);
    # concept_in_dim carries the raw embedding width — looks intentional
    # (projection inside the model), but confirm against LMGraphNet.
    model = LMGraphNet(model_name=args.encoder, encoder_pooler=args.encoder_pooler,
                       concept_num=concept_num, concept_dim=relation_dim,
                       relation_num=relation_num, relation_dim=relation_dim, concept_in_dim=concept_dim,
                       hidden_size=args.mlp_dim, num_attention_heads=args.att_head_num,
                       fc_size=args.fc_dim, num_fc_layers=args.fc_layer_num, dropout=args.dropoutm,
                       edge_weight_dropout=args.edge_weight_dropout,
                       pretrained_concept_emb=cp_emb,  pretrained_relation_emb=rel_emb,
                       freeze_ent_emb=args.freeze_ent_emb, num_layers=args.num_gnn_layers,
                       ablation=args.ablation, emb_scale=args.emb_scale,
                       aristo_path=args.aristo_path)

    model.to(device)

    # Standard transformer fine-tuning setup: no weight decay on biases and
    # LayerNorm parameters; separate learning rates for encoder and decoder.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    grouped_parameters = [
        {'params': [p for n, p in model.encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.decoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.decoder_lr},
    ]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    # NOTE(review): scheduler stays unbound if args.lr_schedule is none of
    # these three values — scheduler.step() below would raise NameError.
    # Presumably the argparse choices restrict it; confirm.
    if args.lr_schedule == 'fixed':
        scheduler = get_constant_schedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs * (lm_data_loader.train_size() / args.batch_size))
        if args.warmup_ratio is not None:
            warmup_steps = int(args.warmup_ratio * max_steps)
        else:
            warmup_steps = args.warmup_steps
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps)

    logging.info('parameters:')
    for name, param in model.decoder.named_parameters():
        if param.requires_grad:
            logging.info('\t{:45}\ttrainable\t{}'.format(name, param.size()))
        else:
            logging.info('\t{:45}\tfixed\t{}'.format(name, param.size()))
    num_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
    logging.info(f'\ttotal: {num_params}')

    loss_func = nn.CrossEntropyLoss(reduction='mean')

    ###################################################################################################
    #   Training                                                                                      #
    ###################################################################################################

    logging.info('')
    logging.info('-' * 71)
    global_step, eval_id, best_dev_id, best_dev_step = 0, 0, 0, 0
    best_dev_acc, final_test_acc, best_test_acc, total_loss = 0.0, 0.0, 0.0, 0.0
    best_test_acc = 0.0  # NOTE(review): redundant — already initialized above
    exit_training = False
    train_start_time = time.time()
    start_time = train_start_time
    model.train()
    # Encoder starts frozen; it is unfrozen at unfreeze_epoch (and may be
    # refrozen at refreeze_epoch) — only the GNN decoder trains initially.
    freeze_net(model.encoder)
    try:
        binary_score_lst = []
        for epoch_id in range(args.n_epochs):
            if exit_training:
                break
            if epoch_id == args.unfreeze_epoch:
                logging.info('encoder unfreezed')
                unfreeze_net(model.encoder)
            if epoch_id == args.refreeze_epoch:
                logging.info('encoder refreezed')
                freeze_net(model.encoder)
            model.train()
            i = 0
            optimizer.zero_grad()
            # Re-shuffle train indexes each epoch; graph and LM loaders must
            # stay aligned batch-for-batch.
            train_index = lm_data_loader.get_train_indexes()
            train_graph_loader, train_avg_node_num, train_avg_edge_num = graph_data_loader.get_pyg_loader(train_index)
            train_lm_data_loader = lm_data_loader.train(train_index)
            assert len(train_graph_loader) == len(train_lm_data_loader)
            for graph, (qids, labels, *lm_input_data) in zip(train_graph_loader, train_lm_data_loader):
                graph = graph.to(device)
                edge_index = graph.edge_index
                row, col = edge_index
                node_batch = graph.batch
                num_of_nodes = graph.num_of_nodes
                num_of_edges = graph.num_of_edges
                rel_ids_embs = graph.edge_attr
                c_ids = graph.x
                c_types = graph.node_type
                logits, unnormalized_wts, normalized_wts = model(*lm_input_data, edge_index=edge_index, c_ids=c_ids, c_types=c_types, node_batch=node_batch, rel_ids_embs=rel_ids_embs, num_of_nodes=num_of_nodes, num_of_edges=num_of_edges)
                loss = loss_func(logits, labels)  # scale: loss per question
                if 'no_edge_weight' not in args.ablation and 'GAT' not in args.ablation:  # add options for other kinds of sparsity
                    # Sparsity regularizer: entropy of the learned edge-weight
                    # distribution, averaged per graph.  Small epsilon guards log(0).
                    log_wts = torch.log(normalized_wts + 0.0000001)
                    entropy = - normalized_wts * log_wts  # entropy: [num_of_edges in the batched graph, 1]
                    entropy = scatter_mean(entropy, node_batch[row], dim=0, dim_size=args.mini_batch_size * args.num_choice)
                    loss += args.alpha * torch.mean(entropy)  # scale: entropy per graph (each question has num_choice graphs)
                loss = loss * args.mini_batch_size / args.batch_size  # will be accumulated for (args.batch_size / args.mini_batch_size) times
                loss.backward()
                total_loss += loss.item()
                if 'no_edge_weight' not in args.ablation and 'GAT' not in args.ablation:
                    binary_score_lst += entropy.squeeze().tolist()
                else:
                    binary_score_lst.append(0)
                i = i + args.mini_batch_size
                # Step the optimizer once a full (accumulated) batch is done.
                if i % args.batch_size == 0:
                    if args.max_grad_norm > 0:
                        nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    optimizer.step()  # bp: scale: loss per question
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1
                    if global_step % args.log_interval == 0:
                        total_loss /= args.log_interval
                        ms_per_batch = 1000 * (time.time() - start_time) / args.log_interval
                        logging.info('| step {:5} | lr: {:9.7f} | loss {:7.20f} | entropy score {:7.4f} | ms/batch {:7.2f} |'
                                     .format(global_step, scheduler.get_lr()[0], total_loss, np.mean(binary_score_lst), ms_per_batch))
                        total_loss = 0
                        binary_score_lst = []
                        start_time = time.time()
                    # Step-based evaluation path (eval_interval > 0).
                    if args.eval_interval > 0:
                        if global_step % args.eval_interval == 0:
                            eval_id += 1
                            model.eval()
                            dev_acc = evaluate_accuracy(dev_graph_loader, dev_lm_data_loader, model, device)
                            test_acc = evaluate_accuracy(test_graph_loader, test_lm_data_loader, model, device)
                            # test_acc = 0.2
                            best_test_acc = max(best_test_acc, test_acc)
                            logging.info('-' * 71)
                            logging.info('| step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(global_step, dev_acc, test_acc))
                            logging.info('-' * 71)
                            # Model selection on dev; record the paired test accuracy.
                            if dev_acc >= best_dev_acc:
                                best_dev_acc = dev_acc
                                final_test_acc = test_acc
                                best_dev_id = eval_id
                                best_dev_step = global_step
                                if args.save_model:
                                    torch.save(model.state_dict(), model_path)
                                    copyfile(model_path, f'{model_path}_{global_step}_{dev_acc*100:.2f}_{test_acc*100:.2f}.pt')  # tmp
                                logging.info(f'model saved to {model_path}')
                            else:
                                logging.info(f'hit patience {eval_id - best_dev_id}/{args.patience}')
                            model.train()
                            # Early stopping only counts once the encoder is unfrozen.
                            if epoch_id > args.unfreeze_epoch and eval_id - best_dev_id >= args.patience:
                                exit_training = True
                                break
            # Epoch-based evaluation path (eval_interval == 0): same bookkeeping
            # as above, once per epoch.
            if args.eval_interval == 0:
                eval_id += 1
                model.eval()
                dev_acc = evaluate_accuracy(dev_graph_loader, dev_lm_data_loader, model, device)
                test_acc = evaluate_accuracy(test_graph_loader, test_lm_data_loader, model, device)
                best_test_acc = max(best_test_acc, test_acc)
                logging.info('-' * 71)
                logging.info('| epoch {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(epoch_id, dev_acc, test_acc))
                logging.info('-' * 71)
                if dev_acc >= best_dev_acc:
                    best_dev_acc = dev_acc
                    final_test_acc = test_acc
                    best_dev_id = eval_id
                    best_dev_step = global_step
                    if args.save_model:
                        torch.save(model.state_dict(), model_path)
                    logging.info(f'model saved to {model_path}')
                else:
                    logging.info(f'hit patience {eval_id - best_dev_id}/{args.patience}')
                model.train()
                if epoch_id > args.unfreeze_epoch and eval_id - best_dev_id >= args.patience:
                    exit_training = True
                    break
            start_time = time.time()
    except KeyboardInterrupt:
        logging.info('-' * 89)
        logging.info('Exiting from training early')
    train_end_time = time.time()
    logging.info('')
    logging.info(f'training ends in {global_step} steps, {train_end_time - train_start_time:.0f} s')
    logging.info('best dev acc: {:.4f} (at step {})'.format(best_dev_acc, best_dev_step))
    logging.info('final test acc: {:.4f}'.format(final_test_acc))
    if args.use_last_epoch:
        # NOTE(review): dev_acc/test_acc are unbound here if training was
        # interrupted before the first evaluation — would raise NameError.
        logging.info(f'last dev acc: {dev_acc:.4f}')
        logging.info(f'last test acc: {test_acc:.4f}')
        return dev_acc, test_acc, best_test_acc
    else:
        return best_dev_acc, final_test_acc, best_test_acc
Пример #12
0
    def train(self, train_loader, local_master, trainLoss_dict, trainLossKey):
        """Run one training epoch of AANet over ``train_loader``.

        Accumulates gradients over ``args.accumulation_steps`` mini-batches,
        logs losses/metrics to TensorBoard, appends per-epoch averages to
        ``trainLoss_dict[trainLossKey]``, and — when validation is disabled —
        saves checkpoints.  ``local_master`` gates checkpoint writing in
        distributed runs.
        """
        args = self.args
        logger = self.logger

        # len(train_loader) is the number of mini-batches.
        # NOTE(review): true division — steps_per_epoch is a float.
        steps_per_epoch = len(train_loader) / args.accumulation_steps
        device = self.device

        self.aanet.train()  # switch the model to training mode

        if args.freeze_bn:
            def set_bn_eval(m):
                # Keep every BatchNorm layer in eval mode (frozen running stats).
                classname = m.__class__.__name__
                if classname.find('BatchNorm') != -1:
                    m.eval()

            # nn.Module.apply runs set_bn_eval recursively on every submodule.
            self.aanet.apply(
                set_bn_eval)

        # Learning rate summary
        base_lr = self.optimizer.param_groups[0]['lr']
        offset_lr = self.optimizer.param_groups[1]['lr']
        self.train_writer.add_scalar('lr/base_lr', base_lr, self.epoch + 1)
        self.train_writer.add_scalar('lr/offset_lr', offset_lr, self.epoch + 1)

        last_print_time = time.time()

        # Per-epoch accumulators for the metrics recorded in trainLoss_dict.
        validate_count = 0
        total_epe = 0
        total_d1 = 0
        total_thres1 = 0
        total_thres2 = 0
        total_thres3 = 0
        total_thres10 = 0
        total_thres20 = 0
        loss_acum = 0
        for i, sample in enumerate(train_loader):
            left = sample['left'].to(device)  # [B, 3, H, W]
            right = sample['right'].to(device)
            gt_disp = sample['disp'].to(device)  # [B, H, W]

            # KITTI convention: disparity 0 marks invalid pixels.
            mask = (gt_disp > 0) & (gt_disp < args.max_disp)

            if args.load_pseudo_gt:
                pseudo_gt_disp = sample['pseudo_disp'].to(device)
                # inverse mask: pixels invalid in gt but covered by pseudo gt
                pseudo_mask = (pseudo_gt_disp > 0) & (pseudo_gt_disp < args.max_disp) & (
                    ~mask)

            # Skip samples with no valid ground-truth pixels at all.
            if not mask.any():
                continue

            # Distributed gradient accumulation: under DDP, skip gradient
            # synchronization (no_sync) for every step whose index is not a
            # multiple of args.accumulation_steps, so the all-reduce happens
            # only once per effective batch.
            # See https://blog.csdn.net/a40850273/article/details/111829836
            my_context = self.aanet.no_sync if args.distributed and (
                        i + 1) % args.accumulation_steps != 0 else nullcontext
            with my_context():
                pred_disp_pyramid = self.aanet(left, right)  # list of H/12, H/6, H/3, H/2, H

                if args.highest_loss_only:
                    pred_disp_pyramid = [pred_disp_pyramid[-1]]  # only the last highest resolution output

                disp_loss = 0
                pseudo_disp_loss = 0
                pyramid_loss = []
                pseudo_pyramid_loss = []

                # Loss weights
                if len(pred_disp_pyramid) == 5:
                    pyramid_weight = [1 / 3, 2 / 3, 1.0, 1.0, 1.0]  # AANet and AANet+
                elif len(pred_disp_pyramid) == 4:
                    pyramid_weight = [1 / 3, 2 / 3, 1.0, 1.0]
                elif len(pred_disp_pyramid) == 3:
                    pyramid_weight = [1.0, 1.0, 1.0]  # 1 scale only
                elif len(pred_disp_pyramid) == 1:
                    pyramid_weight = [1.0]  # highest loss only
                else:
                    raise NotImplementedError

                assert len(pyramid_weight) == len(pred_disp_pyramid)
                for k in range(len(pred_disp_pyramid)):
                    pred_disp = pred_disp_pyramid[k]
                    weight = pyramid_weight[k]

                    if pred_disp.size(-1) != gt_disp.size(-1):
                        pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                        # The trailing scale factor is required: upsampling a
                        # disparity map must scale disparity values accordingly.
                        pred_disp = F.interpolate(pred_disp, size=(gt_disp.size(-2), gt_disp.size(-1)),
                                                  mode='bilinear', align_corners=False) * (
                                                gt_disp.size(-1) / pred_disp.size(-1))
                        pred_disp = pred_disp.squeeze(1)  # [B, H, W]

                    curr_loss = F.smooth_l1_loss(pred_disp[mask], gt_disp[mask],
                                                 reduction='mean')
                    disp_loss += weight * curr_loss
                    pyramid_loss.append(curr_loss)

                    # Pseudo gt loss
                    if args.load_pseudo_gt:
                        pseudo_curr_loss = F.smooth_l1_loss(pred_disp[pseudo_mask], pseudo_gt_disp[pseudo_mask],
                                                            reduction='mean')
                        pseudo_disp_loss += weight * pseudo_curr_loss

                        pseudo_pyramid_loss.append(pseudo_curr_loss)

                total_loss = disp_loss + pseudo_disp_loss

                # Scale down so accumulated gradients match a full-batch step.
                total_loss /= args.accumulation_steps
                total_loss.backward()

                # Bookkeeping only: accumulate metrics for later analysis.
                with torch.no_grad():
                    validate_count += 1
                    total_epe += F.l1_loss(gt_disp[mask], pred_disp_pyramid[-1][mask], reduction='mean').detach().cpu().numpy()
                    total_d1 += d1_metric(pred_disp_pyramid[-1], gt_disp, mask).detach().cpu().numpy()
                    total_thres1 += thres_metric(pred_disp_pyramid[-1], gt_disp, mask, 1.0).detach().cpu().numpy()  # mask.shape=[B, H, W]
                    total_thres2 += thres_metric(pred_disp_pyramid[-1], gt_disp, mask, 2.0).detach().cpu().numpy()
                    total_thres3 += thres_metric(pred_disp_pyramid[-1], gt_disp, mask, 3.0).detach().cpu().numpy()
                    total_thres10 += thres_metric(pred_disp_pyramid[-1], gt_disp, mask, 10.0).detach().cpu().numpy()
                    total_thres20 += thres_metric(pred_disp_pyramid[-1], gt_disp, mask, 20.0).detach().cpu().numpy()
                    loss_acum += total_loss.detach().cpu().numpy()

            # Optimizer step once a full accumulated batch has been processed.
            if (i + 1) % args.accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

                self.num_iter += 1

                if self.num_iter % args.print_freq == 0:
                    this_cycle = time.time() - last_print_time
                    last_print_time += this_cycle

                    # Estimated remaining training time, in hours.
                    time_to_finish = (args.max_epoch - self.epoch) * (1.0 * steps_per_epoch / args.print_freq) * \
                                     this_cycle / 3600.0

                    logger.info('Epoch: [%3d/%3d] [%5d/%5d] time: %4.2fs remainT: %4.2fh disp_loss: %.3f' %
                                (self.epoch + 1, args.max_epoch, self.num_iter, steps_per_epoch, this_cycle,
                                 time_to_finish,
                                 disp_loss.item()))
                    # self.num_iter: total optimizer updates so far.
                    # steps_per_epoch: optimizer updates per epoch.

                if self.num_iter % args.summary_freq == 0:
                    img_summary = dict()
                    img_summary['left'] = left  # [B, C=3, H, W]
                    img_summary['right'] = right  # [B, C=3, H, W]
                    img_summary['gt_disp'] = gt_disp  # [B, H, W]

                    if args.load_pseudo_gt:
                        img_summary['pseudo_gt_disp'] = pseudo_gt_disp

                    # Save pyramid disparity prediction
                    for s in range(len(pred_disp_pyramid)):
                        # Scale from low to high, reverse
                        # pred_disp0 -> pred_disp4: high resolution -> low resolution
                        save_name = 'pred_disp' + str(
                            len(pred_disp_pyramid) - s - 1)
                        save_value = pred_disp_pyramid[s]
                        img_summary[save_name] = save_value

                    pred_disp = pred_disp_pyramid[-1]

                    if pred_disp.size(-1) != gt_disp.size(-1):
                        pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                        pred_disp = F.interpolate(pred_disp, size=(gt_disp.size(-2), gt_disp.size(-1)),
                                                  mode='bilinear', align_corners=False) * (
                                                gt_disp.size(-1) / pred_disp.size(-1))
                        pred_disp = pred_disp.squeeze(1)  # [B, H, W]
                    img_summary['disp_error'] = disp_error_img(pred_disp, gt_disp)  # [B, C=3, H, W]

                    save_images(self.train_writer, 'train', img_summary, self.num_iter)

                    epe = F.l1_loss(gt_disp[mask], pred_disp[mask], reduction='mean')

                    self.train_writer.add_scalar('train/epe', epe.item(), self.num_iter)
                    self.train_writer.add_scalar('train/disp_loss', disp_loss.item(), self.num_iter)
                    self.train_writer.add_scalar('train/total_loss', total_loss.item(), self.num_iter)

                    # Save loss of different scale
                    # loss0 -> loss4: low resolution -> high resolution
                    for s in range(len(pyramid_loss)):
                        save_name = 'train/loss' + str(len(pyramid_loss) - s - 1)
                        save_value = pyramid_loss[s]
                        self.train_writer.add_scalar(save_name, save_value, self.num_iter)

                    d1 = d1_metric(pred_disp, gt_disp, mask)  # pred_disp.shape=[B, H, W], gt_disp.shape=[B, H, W]
                    self.train_writer.add_scalar('train/d1', d1.item(), self.num_iter)
                    thres1 = thres_metric(pred_disp, gt_disp, mask, 1.0)  # mask.shape=[B, H, W]
                    thres2 = thres_metric(pred_disp, gt_disp, mask, 2.0)
                    thres3 = thres_metric(pred_disp, gt_disp, mask, 3.0)
                    thres10 = thres_metric(pred_disp, gt_disp, mask, 10.0)
                    thres20 = thres_metric(pred_disp, gt_disp, mask, 20.0)
                    self.train_writer.add_scalar('train/thres1', thres1.item(), self.num_iter)
                    self.train_writer.add_scalar('train/thres2', thres2.item(), self.num_iter)
                    self.train_writer.add_scalar('train/thres3', thres3.item(), self.num_iter)
                    self.train_writer.add_scalar('train/thres10', thres10.item(), self.num_iter)
                    self.train_writer.add_scalar('train/thres20', thres20.item(), self.num_iter)

        self.epoch += 1

        # Record per-epoch averages (later exported to a MATLAB .mat file
        # for analysis and comparison).
        trainLoss_dict[trainLossKey]['epochs'].append(self.epoch)
        trainLoss_dict[trainLossKey]['avgEPE'].append(total_epe / validate_count)
        trainLoss_dict[trainLossKey]['avg_d1'].append(total_d1 / validate_count)
        trainLoss_dict[trainLossKey]['avg_thres1'].append(total_thres1 / validate_count)
        trainLoss_dict[trainLossKey]['avg_thres2'].append(total_thres2 / validate_count)
        trainLoss_dict[trainLossKey]['avg_thres3'].append(total_thres3 / validate_count)
        trainLoss_dict[trainLossKey]['avg_thres10'].append(total_thres10 / validate_count)
        trainLoss_dict[trainLossKey]['avg_thres20'].append(total_thres20 / validate_count)
        trainLoss_dict[trainLossKey]['avg_loss'].append(loss_acum / validate_count)

        # End-of-epoch bookkeeping:
        # When args.no_validate is True, self.validate() will not run later,
        # so the latest checkpoint (for resuming) and the periodic per-epoch
        # checkpoints must be saved here; otherwise validate() saves them.
        # Always save the latest model for resuming training
        if args.no_validate:
            # Only the local master process writes checkpoints.
            utils.save_checkpoint(args.checkpoint_dir, self.optimizer, self.aanet,
                                  epoch=self.epoch, num_iter=self.num_iter,
                                  epe=-1, best_epe=self.best_epe,
                                  best_epoch=self.best_epoch,
                                  filename='aanet_latest.pth') if local_master else None

            # Save checkpoint of specific epoch
            if self.epoch % args.save_ckpt_freq == 0:
                model_dir = os.path.join(args.checkpoint_dir, 'models')
                utils.check_path(model_dir)
                utils.save_checkpoint(model_dir, self.optimizer, self.aanet,
                                      epoch=self.epoch, num_iter=self.num_iter,
                                      epe=-1, best_epe=self.best_epe,
                                      best_epoch=self.best_epoch,
                                      save_optimizer=False) if local_master else None
Пример #13
0
    def validate(self, val_loader, local_master, valLossDict, valLossKey):
        """Run one validation pass and record/aggregate disparity metrics.

        Computes EPE, D1 and threshold metrics over ``val_loader``, appends
        the epoch averages to ``valLossDict[valLossKey]`` (later exported as
        a MATLAB .mat file), logs scalars/images to TensorBoard, appends a
        summary line to ``val_results.txt`` and -- when not in evaluate-only
        mode -- saves the best and latest checkpoints (rank-0 process only).

        Args:
            val_loader: iterable yielding dicts with 'left', 'right'
                [B, 3, H, W], 'disp' [B, H, W] and 'left_name' entries.
            local_master: True on the rank-0 process; only it writes
                checkpoints.
            valLossDict: per-run metric history dict (for MATLAB export).
            valLossKey: key selecting which history in ``valLossDict`` to
                append to.
        """
        args = self.args
        logger = self.logger
        logger.info('=> Start validation...')

        # Evaluate-only mode: load a trained model from disk. Otherwise
        # validate the in-memory self.aanet that is still being trained.
        if args.evaluate_only:
            if args.pretrained_aanet is not None:
                pretrained_aanet = args.pretrained_aanet
            else:
                model_name = 'aanet_best.pth'
                pretrained_aanet = os.path.join(args.checkpoint_dir, model_name)
                if not os.path.exists(pretrained_aanet):  # KITTI without validation
                    pretrained_aanet = pretrained_aanet.replace(model_name, 'aanet_latest.pth')

            logger.info('=> loading pretrained aanet: %s' % pretrained_aanet)
            utils.load_pretrained_net(self.aanet, pretrained_aanet, no_strict=True)

        self.aanet.eval()

        num_samples = len(val_loader)
        logger.info('=> %d samples found in the validation set' % num_samples)

        # Running sums of per-sample metrics; averaged over valid_samples.
        val_epe = 0
        val_d1 = 0
        val_thres1 = 0
        val_thres2 = 0
        val_thres3 = 0
        val_thres10 = 0
        val_thres20 = 0

        val_count = 0  # index used to tag visualization images

        val_file = os.path.join(args.checkpoint_dir, 'val_results.txt')

        num_imgs = 0
        valid_samples = 0

        # Iterate over the validation (or test) samples.
        for i, sample in enumerate(val_loader):
            if (i + 1) % 100 == 0:
                # Fix: log the 1-based progress count the condition actually
                # checks (previously the 0-based index was printed).
                logger.info('=> Validating %d/%d' % (i + 1, num_samples))

            left = sample['left'].to(self.device)  # [B, 3, H, W]
            right = sample['right'].to(self.device)
            gt_disp = sample['disp'].to(self.device)  # [B, H, W]
            mask = (gt_disp > 0) & (gt_disp < args.max_disp)

            # Skip samples with no valid ground-truth pixels at all.
            if not mask.any():
                continue

            valid_samples += 1

            num_imgs += gt_disp.size(0)

            with torch.no_grad():
                disparity_pyramid = self.aanet(left, right)  # [B, H, W]
                pred_disp = disparity_pyramid[-1]

            if pred_disp.size(-1) < gt_disp.size(-1):
                # Upsample a low-resolution prediction to ground-truth size.
                # Disparity values are scaled by the width ratio; the RHS
                # pred_disp.size(-1) is evaluated before reassignment, so it
                # is the pre-interpolation width.
                pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                pred_disp = F.interpolate(pred_disp, (gt_disp.size(-2), gt_disp.size(-1)),
                                          mode='bilinear', align_corners=False) * (
                                        gt_disp.size(-1) / pred_disp.size(-1))
                pred_disp = pred_disp.squeeze(1)  # [B, H, W]

            epe = F.l1_loss(gt_disp[mask], pred_disp[mask], reduction='mean')
            d1 = d1_metric(pred_disp, gt_disp, mask)
            thres1 = thres_metric(pred_disp, gt_disp, mask, 1.0)
            thres2 = thres_metric(pred_disp, gt_disp, mask, 2.0)
            thres3 = thres_metric(pred_disp, gt_disp, mask, 3.0)
            thres10 = thres_metric(pred_disp, gt_disp, mask, 10.0)
            thres20 = thres_metric(pred_disp, gt_disp, mask, 20.0)

            val_epe += epe.item()
            val_d1 += d1.item()
            val_thres1 += thres1.item()
            val_thres2 += thres2.item()
            val_thres3 += thres3.item()
            val_thres10 += thres10.item()
            val_thres20 += thres20.item()

            # Dump per-sample tensors for offline error analysis.
            # saveForErrorAnalysis(index, img_name, dstPath, dstName, left, right, gt_disp, disparity_pyramid):
            with torch.no_grad():
                saveImgErrorAnalysis(i, sample['left_name'], './myDataAnalysis', 'SceneFlow_valIdx_{}'.format(i),
                                     left, right, gt_disp, disparity_pyramid, disp_error_img(pred_disp, gt_disp))

            # Save a handful of evenly spaced samples for visualization.
            if not args.evaluate_only or args.mode == 'test':
                # if i in [num_samples // 4, num_samples // 2, num_samples // 4 * 3]:
                if i in [num_samples // 6, num_samples // 6 * 2, num_samples // 6 * 3, num_samples // 6 * 4,
                         num_samples // 6 * 5]:
                    img_summary = dict()
                    img_summary['disp_error'] = disp_error_img(pred_disp, gt_disp)
                    img_summary['left'] = left
                    img_summary['right'] = right
                    img_summary['gt_disp'] = gt_disp
                    img_summary['pred_disp'] = pred_disp
                    save_images(self.train_writer, 'val' + str(val_count), img_summary, self.epoch)

                    disp_error = disp_error_hist(pred_disp, gt_disp, args.max_disp)
                    save_hist(self.train_writer, '{}/{}'.format('val' + str(val_count), 'hist'), disp_error, self.epoch)

                    val_count += 1
        # Done iterating over the validation/test samples.

        logger.info('=> Validation done!')

        if valid_samples == 0:
            # Fix: every sample was skipped (no valid ground-truth pixels);
            # the averages below would raise ZeroDivisionError.
            logger.warning('=> No valid samples in the validation set, skipping metric aggregation')
            return

        mean_epe = val_epe / valid_samples
        mean_d1 = val_d1 / valid_samples
        mean_thres1 = val_thres1 / valid_samples
        mean_thres2 = val_thres2 / valid_samples
        mean_thres3 = val_thres3 / valid_samples
        mean_thres10 = val_thres10 / valid_samples
        mean_thres20 = val_thres20 / valid_samples

        # Record epoch averages for later export as a MATLAB .mat file.
        valLossDict[valLossKey]["epochs"].append(self.epoch)
        valLossDict[valLossKey]["avgEPE"].append(mean_epe)
        valLossDict[valLossKey]["avg_d1"].append(mean_d1)
        valLossDict[valLossKey]["avg_thres1"].append(mean_thres1)
        valLossDict[valLossKey]["avg_thres2"].append(mean_thres2)
        valLossDict[valLossKey]["avg_thres3"].append(mean_thres3)
        valLossDict[valLossKey]["avg_thres10"].append(mean_thres10)
        valLossDict[valLossKey]["avg_thres20"].append(mean_thres20)

        # Save validation results
        with open(val_file, 'a') as f:
            f.write('epoch: %03d\t' % self.epoch)
            f.write('epe: %.3f\t' % mean_epe)
            f.write('d1: %.4f\t' % mean_d1)
            f.write('thres1: %.4f\t' % mean_thres1)
            f.write('thres2: %.4f\t' % mean_thres2)
            f.write('thres3: %.4f\t' % mean_thres3)
            f.write('thres10: %.4f\t' % mean_thres10)
            f.write('thres20: %.4f\n' % mean_thres20)
            f.write('dataset_name= %s\t mode=%s\n' % (args.dataset_name, args.mode))

        logger.info('=> Mean validation epe of epoch %d: %.3f' % (self.epoch, mean_epe))

        if not args.evaluate_only:
            self.train_writer.add_scalar('val/epe', mean_epe, self.epoch)
            self.train_writer.add_scalar('val/d1', mean_d1, self.epoch)
            self.train_writer.add_scalar('val/thres1', mean_thres1, self.epoch)
            self.train_writer.add_scalar('val/thres2', mean_thres2, self.epoch)
            self.train_writer.add_scalar('val/thres3', mean_thres3, self.epoch)
            self.train_writer.add_scalar('val/thres10', mean_thres10, self.epoch)
            self.train_writer.add_scalar('val/thres20', mean_thres20, self.epoch)

        if not args.evaluate_only:
            if args.val_metric == 'd1':
                if mean_d1 < self.best_epe:
                    # Actually best_epe here is d1
                    self.best_epe = mean_d1
                    self.best_epoch = self.epoch

                    utils.save_checkpoint(args.checkpoint_dir, self.optimizer, self.aanet,
                                          epoch=self.epoch, num_iter=self.num_iter,
                                          epe=mean_d1, best_epe=self.best_epe,
                                          best_epoch=self.best_epoch,
                                          filename='aanet_best.pth') if local_master else None
            elif args.val_metric == 'epe':
                if mean_epe < self.best_epe:
                    self.best_epe = mean_epe
                    self.best_epoch = self.epoch

                    utils.save_checkpoint(args.checkpoint_dir, self.optimizer, self.aanet,
                                          epoch=self.epoch, num_iter=self.num_iter,
                                          epe=mean_epe, best_epe=self.best_epe,
                                          best_epoch=self.best_epoch,
                                          filename='aanet_best.pth') if local_master else None
            else:
                raise NotImplementedError

        if self.epoch == args.max_epoch:
            # Save best validation results
            with open(val_file, 'a') as f:
                f.write('\nbest epoch: %03d \t best %s: %.3f\n\n' % (self.best_epoch,
                                                                     args.val_metric,
                                                                     self.best_epe))

            logger.info('=> best epoch: %03d \t best %s: %.3f\n' % (self.best_epoch,
                                                                    args.val_metric,
                                                                    self.best_epe))

        # Always save the latest model for resuming training
        if not args.evaluate_only:
            utils.save_checkpoint(args.checkpoint_dir, self.optimizer, self.aanet,
                                  epoch=self.epoch, num_iter=self.num_iter,
                                  epe=mean_epe, best_epe=self.best_epe,
                                  best_epoch=self.best_epoch,
                                  filename='aanet_latest.pth') if local_master else None

            # Save checkpoint of specific epochs
            if self.epoch % args.save_ckpt_freq == 0:
                model_dir = os.path.join(args.checkpoint_dir, 'models')
                utils.check_path(model_dir)
                utils.save_checkpoint(model_dir, self.optimizer, self.aanet,
                                      epoch=self.epoch, num_iter=self.num_iter,
                                      epe=mean_epe, best_epe=self.best_epe,
                                      best_epoch=self.best_epoch,
                                      save_optimizer=False) if local_master else None
Пример #14
0
    def validate(self, val_loader):
        """Validate the current AANet model, reporting disparity and depth metrics.

        Iterates over ``val_loader``; for the custom_* datasets the stored
        'disp' maps are converted between depth and disparity via the stereo
        relation ``depth = baseline * 1000 * fx / 2 / disparity``. Accumulates
        EPE/D1/bad-pixel and millimetre-error metrics, appends results to
        ``val_results.txt`` and TensorBoard, and (when training) saves the
        best/latest checkpoints.

        NOTE(review): this method references several names not defined in the
        visible scope (``inf``, ``onlyObj``, ``perObject``, ``objectCount``,
        ``perObjectDisp``, ``perObjectDepth``, ``apply_disparity_cu``,
        ``bad``, ``mm_error``, ``depth_error_img``) -- presumably
        module-level globals/helpers; confirm they exist before reuse.
        """
        args = self.args
        logger = self.logger
        logger.info('=> Start validation...')

        # Evaluate-only mode: load a trained checkpoint from disk instead of
        # using the in-memory (possibly still-training) model.
        if args.evaluate_only is True:
            if args.pretrained_aanet is not None:
                pretrained_aanet = args.pretrained_aanet
            else:
                model_name = 'aanet_best.pth'
                pretrained_aanet = os.path.join(args.checkpoint_dir,
                                                model_name)
                if not os.path.exists(
                        pretrained_aanet):  # KITTI without validation
                    pretrained_aanet = pretrained_aanet.replace(
                        model_name, 'aanet_latest.pth')

            logger.info('=> loading pretrained aanet: %s' % pretrained_aanet)
            utils.load_pretrained_net(self.aanet,
                                      pretrained_aanet,
                                      no_strict=True)

        # NOTE(review): train() (not eval()) is called here, so batch-norm /
        # dropout stay in training mode during validation -- confirm this is
        # intentional.
        self.aanet.train()

        num_samples = len(val_loader)
        logger.info('=> %d samples found in the validation set' % num_samples)

        # Running sums of per-sample metrics; averaged over valid_samples.
        val_epe = 0
        val_d1 = 0
        # val_thres1 = 0
        # val_thres2 = 0
        # val_thres3 = 0
        val_bad1 = 0
        val_bad2 = 0
        val_abs = 0
        val_mm2 = 0
        val_mm4 = 0
        val_mm8 = 0

        val_count = 0  # index used to tag visualization images

        val_file = os.path.join(args.checkpoint_dir, 'val_results.txt')

        num_imgs = 0
        valid_samples = 0

        # Default camera parameters (baseline in metres, intrinsics in
        # pixels) used when the dataset does not supply per-sample values;
        # the custom_* branches below overwrite them per sample.
        baseline = 0.055
        intrinsic = [[1387.095, 0.0, 960.0], [0.0, 1387.095, 540.0],
                     [0.0, 0.0, 1.0]]

        for i, sample in enumerate(val_loader):
            if i % 100 == 0:
                logger.info('=> Validating %d/%d' % (i, num_samples))

            left = sample['left'].to(self.device)  # [B, 3, H, W]
            right = sample['right'].to(self.device)
            gt_disp = sample['disp'].to(self.device)  # [B, H, W]
            gt_depth = []
            pred_disp = []

            # 'disp' actually stores scaled depth for this dataset (raw
            # value * 256 -- presumably depth in mm; TODO confirm units);
            # convert it to disparity, zeroing the inf entries produced by
            # division where the stored value was 0.
            if args.dataset_name == 'custom_dataset':  # going to be depthL_fromR_down if from  custom_dataset
                gt_disp_1 = (baseline * 1000 * intrinsic[0][0] /
                             2) / (gt_disp * 256.)
                gt_disp_1[gt_disp_1 == inf] = 0
                gt_depth = gt_disp * 256.
                gt_disp = gt_disp_1

            if (args.dataset_name == 'custom_dataset_full'
                    or args.dataset_name == 'custom_dataset_obj'):

                # convert to disparity then apply warp ops
                # (per-sample baseline/intrinsics from the loader).
                temp = gt_disp * 256.
                for x in range(left.shape[0]):
                    baseline = sample['baseline'][x].to(self.device)
                    intrinsic = sample['intrinsic'][x].to(self.device)
                    temp[x] = (baseline * 1000 * intrinsic[0][0] /
                               2) / (temp[x])
                    temp[x][temp[x] == inf] = 0

                # Warp the disparity map into the left view.
                # gt_disp = torch.clone(temp)
                gt_disp = apply_disparity_cu(temp.unsqueeze(1),
                                             temp.type(torch.int))
                gt_disp = torch.squeeze(gt_disp)

                gt_depth = temp
                # convert to gt_depth
                for x in range(left.shape[0]):
                    baseline = sample['baseline'][x].to(self.device)
                    intrinsic = sample['intrinsic'][x].to(self.device)
                    gt_depth[x] = (baseline * 1000 * intrinsic[0][0] /
                                   2) / (gt_disp[x])
                    gt_depth[x][gt_depth[x] == inf] = 0
                gt_depth = gt_depth.to(self.device)

            # Same conversion/warp as above, plus an extra leading batch dim
            # restored after the squeeze.
            if (args.dataset_name == 'custom_dataset_sim'
                    or args.dataset_name == 'custom_dataset_real'):
                temp = gt_disp * 256.
                for x in range(left.shape[0]):
                    baseline = sample['baseline'][x].to(self.device)
                    intrinsic = sample['intrinsic'][x].to(self.device)
                    temp[x] = (baseline * 1000 * intrinsic[0][0] /
                               2) / (temp[x])
                    temp[x][temp[x] == inf] = 0

                # gt_disp = torch.clone(temp)
                gt_disp = apply_disparity_cu(temp.unsqueeze(1),
                                             temp.type(torch.int))
                gt_disp = torch.squeeze(gt_disp)

                gt_disp = torch.unsqueeze(gt_disp, 0)

                gt_depth = temp
                # convert to gt_depth
                for x in range(left.shape[0]):
                    baseline = sample['baseline'][x].to(self.device)
                    intrinsic = sample['intrinsic'][x].to(self.device)
                    gt_depth[x] = (baseline * 1000 * intrinsic[0][0] /
                                   2) / (gt_disp[x])
                    gt_depth[x][gt_depth[x] == inf] = 0
                gt_depth = gt_depth.to(self.device)

            # Valid disparity pixels only.
            mask_disp = (gt_disp > 0.) & (gt_disp < args.max_disp)

            # Skip samples with no valid ground-truth pixels.
            if not mask_disp.any():
                continue

            valid_samples += 1

            num_imgs += gt_disp.size(0)

            with torch.no_grad():
                pred_disp = self.aanet(left, right)[-1]  # [B, H, W]

            if pred_disp.size(-1) < gt_disp.size(-1):
                # Upsample a low-resolution prediction to ground-truth size;
                # disparity values are scaled by the width ratio (RHS
                # pred_disp.size(-1) is evaluated before reassignment).
                pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
                pred_disp = F.interpolate(
                    pred_disp, (gt_disp.size(-2), gt_disp.size(-1)),
                    mode='bilinear',
                    align_corners=False) * (gt_disp.size(-1) /
                                            pred_disp.size(-1))
                pred_disp = pred_disp.squeeze(1)  # [B, H, W]

            # Restrict evaluation to object pixels (labels below 17 --
            # presumably labels >= 17 are background/table; TODO confirm).
            if (onlyObj):
                gt_disp[sample['label'] >= 17] = 0
                mask_disp = (gt_disp > 0.) & (gt_disp < args.max_disp)

            epe = F.l1_loss(gt_disp[mask_disp],
                            pred_disp[mask_disp],
                            reduction='mean')
            d1 = d1_metric(pred_disp, gt_disp, mask_disp)

            bad1 = bad(pred_disp, gt_disp, mask_disp)
            bad2 = bad(pred_disp, gt_disp, mask_disp, threshold=2)

            # Convert the predicted disparity back to depth (mm) using the
            # same stereo relation; per-sample camera parameters for the
            # custom_* datasets, the module defaults otherwise.
            pred_depth = []
            if (args.dataset_name == 'custom_dataset_full'
                    or args.dataset_name == 'custom_dataset_sim'
                    or args.dataset_name == 'custom_dataset_real'
                    or args.dataset_name == 'custom_dataset_obj'):
                temp = torch.zeros((pred_disp.shape)).to(self.device)
                for x in range(left.shape[0]):
                    baseline = sample['baseline'][x].to(self.device)
                    intrinsic = sample['intrinsic'][x].to(self.device)
                    temp[x] = (baseline * 1000 * intrinsic[0][0] /
                               2) / (pred_disp[x])
                    temp[x][temp[x] == inf] = 0
                pred_depth = temp
            else:
                pred_depth = (baseline * 1000 * intrinsic[0][0] /
                              2) / (pred_disp)
                pred_depth[pred_depth == inf] = 0

            # Depth metrics are evaluated within a 2000 mm working range.
            mask_depth = (gt_depth > 0.) & (gt_depth < 2000)

            # NOTE(review): this mask mixes depth and disparity terms
            # (gt_disp < args.max_disp alongside gt_depth > 0) -- possibly a
            # typo for gt_depth < 2000; confirm before relying on it.
            if (onlyObj):
                gt_depth[sample['label'] >= 17] = 0
                mask_depth = (gt_depth > 0.) & (gt_disp < args.max_disp)

            # NOTE(review): 'abs' shadows the builtin of the same name for
            # the rest of this method.
            abs = F.l1_loss(gt_depth[mask_depth],
                            pred_depth[mask_depth],
                            reduction='mean')

            mm2 = mm_error(pred_depth, gt_depth, mask_depth)
            mm4 = mm_error(pred_depth, gt_depth, mask_depth, threshold=4)
            mm8 = mm_error(pred_depth, gt_depth, mask_depth, threshold=8)

            # Clamp out-of-range predictions for the visualizations below.
            pred_depth[pred_depth > 2000] = 0

            # Optional per-object-class error bookkeeping (accumulated into
            # module-level dicts, averaged after the loop).
            if (perObject):
                for x in range(left.shape[0]):
                    labels = sample['label'][x].detach().numpy().astype(
                        np.uint8)
                    for obj in np.unique(labels):
                        gtObjectDepth = gt_depth[x].detach().clone()
                        gtObjectDepth[labels != obj] = 0
                        predObjectDepth = pred_depth[x].detach().clone()
                        predObjectDepth[labels != obj] = 0

                        gtObjectDisp = gt_disp[x].detach().clone()
                        gtObjectDisp[labels != obj] = 0
                        predObjectDisp = pred_disp[x].detach().clone()
                        predObjectDisp[labels != obj] = 0

                        mask_depth = (gtObjectDepth > 0.)
                        mask_disp = (gtObjectDisp > 0.)

                        objectCount[obj] += 1

                        perObjectDisp[obj] += F.l1_loss(
                            gtObjectDisp[mask_disp],
                            predObjectDisp[mask_disp],
                            reduction='mean')
                        perObjectDepth[obj] += F.l1_loss(
                            gtObjectDepth[mask_depth],
                            predObjectDepth[mask_depth],
                            reduction='mean')

            # thres1 = thres_metric(pred_disp, gt_disp, mask, 1.0)
            # thres2 = thres_metric(pred_disp, gt_disp, mask, 2.0)
            # thres3 = thres_metric(pred_disp, gt_disp, mask, 3.0)

            val_epe += epe.item()
            val_d1 += d1.item()
            val_bad1 += bad1.item()
            val_bad2 += bad2.item()
            val_abs += abs.item()
            val_mm2 += mm2.item()
            val_mm4 += mm4.item()
            val_mm8 += mm8.item()
            # val_thres1 += thres1.item()
            # val_thres2 += thres2.item()
            # val_thres3 += thres3.item()

            # Save 3 images for visualization

            if i in [num_samples // 4, num_samples // 2, num_samples // 4 * 3]:
                if args.evaluate_only:

                    # Dump depth/label/meta for the first image of the batch.
                    # NOTE(review): 'x' below is the leaked last index of an
                    # earlier per-batch loop, not this sample's index --
                    # verify this indexing is intended.
                    im = (pred_depth[0]).detach().cpu().numpy().astype(
                        np.uint16)
                    if not os.path.isdir('/cephfs/edward/depths'):
                        os.mkdir('/cephfs/edward/depths')
                    imageio.imwrite('/cephfs/edward/depths/' + str(i) + ".png",
                                    im)

                    im = (gt_depth[0]).detach().cpu().numpy().astype(np.uint16)
                    imageio.imwrite(
                        '/cephfs/edward/depths/' + str(i) + "gt.png", im)

                    imageio.imwrite(
                        '/cephfs/edward/depths/' + str(i) + "label.png",
                        sample['label'][x].detach().numpy().astype(np.uint8))

                    info = {
                        'baseline': sample['baseline'][x],
                        'intrinsic': sample['intrinsic'][x],
                        'object_ids': sample['object_ids'][x],
                        'extrinsic': sample['extrinsic'][x]
                    }
                    filename = '/cephfs/edward/depths/meta' + str(i) + '.pkl'
                    with open(filename, 'wb') as f:
                        pickle.dump(info, f)

                img_summary = {}
                img_summary['left'] = left
                img_summary['right'] = right
                img_summary['gt_depth'] = gt_depth
                img_summary['gt_disp'] = gt_disp

                if (onlyObj):
                    pred_disp[sample['label'] >= 17] = 0
                    pred_depth[sample['label'] >= 17] = 0

                img_summary['disp_error'] = disp_error_img(pred_disp, gt_disp)
                img_summary['depth_error'] = depth_error_img(
                    pred_depth, gt_depth)
                img_summary['pred_disp'] = pred_disp
                img_summary['pred_depth'] = pred_depth

                save_images(self.train_writer, 'val' + str(val_count),
                            img_summary, self.epoch)
                val_count += 1

        logger.info('=> Validation done!')
        # Average the per-object accumulators by occurrence count.
        if (perObject):
            for key, value in objectCount.items():

                perObjectDisp[key] = float(perObjectDisp[key]) / value
                perObjectDepth[key] = float(perObjectDepth[key]) / value
            print(perObjectDisp, perObjectDepth, objectCount)

        # NOTE(review): raises ZeroDivisionError if every sample was skipped
        # (valid_samples == 0).
        mean_epe = val_epe / valid_samples
        mean_d1 = val_d1 / valid_samples
        mean_bad1 = val_bad1 / valid_samples
        mean_bad2 = val_bad2 / valid_samples
        mean_abs = val_abs / valid_samples
        mean_mm2 = val_mm2 / valid_samples
        mean_mm4 = val_mm4 / valid_samples
        mean_mm8 = val_mm8 / valid_samples
        # mean_thres1 = val_thres1 / valid_samples
        # mean_thres2 = val_thres2 / valid_samples
        # mean_thres3 = val_thres3 / valid_samples

        # Save validation results
        # NOTE(review): the final write ends with '\t', not '\n', so
        # successive epochs run together on one line in val_results.txt.
        with open(val_file, 'a') as f:
            f.write('epoch: %03d\t' % self.epoch)
            f.write('epe: %.4f\t' % mean_epe)
            f.write('d1: %.4f\t' % mean_d1)
            f.write('bad1: %.4f\t' % mean_bad1)
            f.write('bad2: %.4f\t' % mean_bad2)
            f.write('abs: %.4f\t' % mean_abs)
            f.write('mm2: %.4f\t' % mean_mm2)
            f.write('mm4: %.4f\t' % mean_mm4)
            f.write('mm8: %.4f\t' % mean_mm8)
            # f.write('thres1: %.4f\t' % mean_thres1)
            # f.write('thres2: %.4f\t' % mean_thres2)
            # f.write('thres3: %.4f\n' % mean_thres3)

        logger.info('=> Mean validation epe of epoch %d: %.3f' %
                    (self.epoch, mean_epe))

        self.train_writer.add_scalar('val/epe', mean_epe, self.epoch)
        self.train_writer.add_scalar('val/d1', mean_d1, self.epoch)
        self.train_writer.add_scalar('val/bad1', mean_bad1, self.epoch)
        self.train_writer.add_scalar('val/bad2', mean_bad2, self.epoch)
        self.train_writer.add_scalar('val/abs', mean_abs, self.epoch)
        self.train_writer.add_scalar('val/mm2', mean_mm2, self.epoch)
        self.train_writer.add_scalar('val/mm4', mean_mm4, self.epoch)
        self.train_writer.add_scalar('val/mm8', mean_mm8, self.epoch)
        # self.train_writer.add_scalar('val/thres1', mean_thres1, self.epoch)
        # self.train_writer.add_scalar('val/thres2', mean_thres2, self.epoch)
        # self.train_writer.add_scalar('val/thres3', mean_thres3, self.epoch)

        # Track the best model by the configured validation metric.
        if not args.evaluate_only:
            if args.val_metric == 'd1':
                if mean_d1 < self.best_epe:
                    # Actually best_epe here is d1
                    self.best_epe = mean_d1
                    self.best_epoch = self.epoch

                    utils.save_checkpoint(args.checkpoint_dir,
                                          self.optimizer,
                                          self.aanet,
                                          epoch=self.epoch,
                                          num_iter=self.num_iter,
                                          epe=mean_d1,
                                          best_epe=self.best_epe,
                                          best_epoch=self.best_epoch,
                                          filename='aanet_best.pth')
            elif args.val_metric == 'epe':
                if mean_epe < self.best_epe:
                    self.best_epe = mean_epe
                    self.best_epoch = self.epoch

                    utils.save_checkpoint(args.checkpoint_dir,
                                          self.optimizer,
                                          self.aanet,
                                          epoch=self.epoch,
                                          num_iter=self.num_iter,
                                          epe=mean_epe,
                                          best_epe=self.best_epe,
                                          best_epoch=self.best_epoch,
                                          filename='aanet_best.pth')
            else:
                raise NotImplementedError

        if self.epoch == args.max_epoch:
            # Save best validation results
            with open(val_file, 'a') as f:
                f.write('\nbest epoch: %03d \t best %s: %.3f\n\n' %
                        (self.best_epoch, args.val_metric, self.best_epe))

            logger.info('=> best epoch: %03d \t best %s: %.3f\n' %
                        (self.best_epoch, args.val_metric, self.best_epe))

        # Always save the latest model for resuming training
        if not args.evaluate_only:
            utils.save_checkpoint(args.checkpoint_dir,
                                  self.optimizer,
                                  self.aanet,
                                  epoch=self.epoch,
                                  num_iter=self.num_iter,
                                  epe=mean_epe,
                                  best_epe=self.best_epe,
                                  best_epoch=self.best_epoch,
                                  filename='aanet_latest.pth')

            # Save checkpoint of specific epochs
            if self.epoch % args.save_ckpt_freq == 0:
                model_dir = os.path.join(args.checkpoint_dir, 'models')
                utils.check_path(model_dir)
                utils.save_checkpoint(model_dir,
                                      self.optimizer,
                                      self.aanet,
                                      epoch=self.epoch,
                                      num_iter=self.num_iter,
                                      epe=mean_epe,
                                      best_epe=self.best_epe,
                                      best_epoch=self.best_epoch,
                                      save_optimizer=False)