Code example #1
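# Apparently an excerpt of the explanation module from the GNNExplainer project;
# the enclosing class statement (an nn.Module subclass, implied by the super()
# call below) is restored here so the snippet parses on its own.
class ExplainModule(nn.Module):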
    def __init__(
        self,
        adj,
        x,
        model,
        label,
        args,
        graph_idx=0,
        writer=None,
        use_sigmoid=True,
        graph_mode=False,
    ):
        super(ExplainModule, self).__init__()
        self.adj = adj
        self.x = x
        self.model = model
        self.label = label
        self.graph_idx = graph_idx
        self.args = args
        self.writer = writer
        self.mask_act = args.mask_act
        self.use_sigmoid = use_sigmoid
        self.graph_mode = graph_mode

        init_strategy = "normal"
        num_nodes = adj.size()[1]
        self.mask, self.mask_bias = self.construct_edge_mask(
            num_nodes, init_strategy=init_strategy)

        self.feat_mask = self.construct_feat_mask(x.size(-1),
                                                  init_strategy="constant")
        params = [self.mask, self.feat_mask]
        if self.mask_bias is not None:
            params.append(self.mask_bias)
        # For masking diagonal entries
        self.diag_mask = torch.ones(num_nodes,
                                    num_nodes) - torch.eye(num_nodes)
        if args.gpu:
            self.diag_mask = self.diag_mask.cuda()

        self.scheduler, self.optimizer = train_utils.build_optimizer(
            args, params)

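        # Coefficients weighting the explanation-loss terms: edge-mask size,
        # feature-mask size, mask entropy, feature-mask entropy, gradient term
        # and Laplacian (connectivity) regularization.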
        self.coeffs = {
            "size": 0.005,
            "feat_size": 1.0,
            "ent": 1.0,
            "feat_ent": 0.1,
            "grad": 0,
            "lap": 1.0,
        }
Code example #2
File: train.py Project: Aditya239233/GNNExplainer
def train_node_classifier(G, labels, model, args, writer=None):
    # train/test split only for nodes
    num_nodes = G.number_of_nodes()
    num_train = int(num_nodes * args.train_ratio)
    idx = [i for i in range(num_nodes)]

    np.random.shuffle(idx)
    train_idx = idx[:num_train]
    test_idx = idx[num_train:]

    data = gengraph.preprocess_input_graph(G, labels)
    labels_train = torch.tensor(data["labels"][:, train_idx], dtype=torch.long)
    adj = torch.tensor(data["adj"], dtype=torch.float)
    x = torch.tensor(data["feat"], requires_grad=True, dtype=torch.float)
    scheduler, optimizer = train_utils.build_optimizer(
        args, model.parameters(), weight_decay=args.weight_decay)
    model.train()
    ypred = None
    for epoch in range(args.num_epochs):
        begin_time = time.time()
        model.zero_grad()

        if args.gpu:
            ypred, adj_att = model(x.cuda(), adj.cuda())
        else:
            ypred, adj_att = model(x, adj)
        ypred_train = ypred[:, train_idx, :]
        if args.gpu:
            loss = model.loss(ypred_train, labels_train.cuda())
        else:
            loss = model.loss(ypred_train, labels_train)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.clip)  # clip_grad_norm (no underscore) is deprecated

        optimizer.step()
        #for param_group in optimizer.param_groups:
        #    print(param_group["lr"])
        elapsed = time.time() - begin_time

        result_train, result_test = evaluate_node(ypred.cpu(), data["labels"],
                                                  train_idx, test_idx)
        if writer is not None:
            writer.add_scalar("loss/avg_loss", loss, epoch)
            writer.add_scalars(
                "prec",
                {
                    "train": result_train["prec"],
                    "test": result_test["prec"]
                },
                epoch,
            )
            writer.add_scalars(
                "recall",
                {
                    "train": result_train["recall"],
                    "test": result_test["recall"]
                },
                epoch,
            )
            writer.add_scalars("acc", {
                "train": result_train["acc"],
                "test": result_test["acc"]
            }, epoch)

        if epoch % 10 == 0:
            print(
                "epoch: ",
                epoch,
                "; loss: ",
                loss.item(),
                "; train_acc: ",
                result_train["acc"],
                "; test_acc: ",
                result_test["acc"],
                "; train_prec: ",
                result_train["prec"],
                "; test_prec: ",
                result_test["prec"],
                "; epoch time: ",
                "{0:0.2f}".format(elapsed),
            )

        if scheduler is not None:
            scheduler.step()
    print(result_train["conf_mat"])
    print(result_test["conf_mat"])

    # computation graph
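    # The dict below bundles everything needed to rerun the trained model later
    # (adjacency, node features, labels, final predictions, train split); in the
    # GNNExplainer workflow this "cg" dict is what the explainer loads back.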
    model.eval()
    if args.gpu:
        ypred, _ = model(x.cuda(), adj.cuda())
    else:
        ypred, _ = model(x, adj)
    cg_data = {
        "adj": data["adj"],
        "feat": data["feat"],
        "label": data["labels"],
        "pred": ypred.cpu().detach().numpy(),
        "train_idx": train_idx,
    }
    # import pdb
    # pdb.set_trace()
    io_utils.save_checkpoint(model,
                             optimizer,
                             args,
                             num_epochs=-1,
                             cg_dict=cg_data)
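
The snippets on this page call several different project-local helpers named train_utils.build_optimizer, each with its own signature. For the GNNExplainer-style calls (examples #1, #2, #3, #6 and #8) the helper takes (args, params, weight_decay=...) and returns a (scheduler, optimizer) pair. The helper itself is not reproduced on this page; the following is only a minimal sketch of what such a helper might look like, with the args fields (opt, lr, opt_scheduler, opt_decay_step, opt_decay_rate) assumed from how the function is called here rather than taken from the project.

import torch.optim as optim


def build_optimizer(args, params, weight_decay=0.0):
    # keep only parameters that require gradients
    params = [p for p in params if p.requires_grad]
    if args.opt == 'adam':
        optimizer = optim.Adam(params, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'sgd':
        optimizer = optim.SGD(params, lr=args.lr, weight_decay=weight_decay)
    else:
        raise ValueError('unknown optimizer: %s' % args.opt)

    scheduler = None
    if getattr(args, 'opt_scheduler', 'none') == 'step':
        scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=args.opt_decay_step, gamma=args.opt_decay_rate)
    # callers unpack the result as: scheduler, optimizer = build_optimizer(...)
    return scheduler, optimizer
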
Code example #3
def train_node_classifier_multigraph(G_list, labels, model, args, writer=None):
    train_idx_all, test_idx_all = [], []
    # train/test split only for nodes
    num_nodes = G_list[0].number_of_nodes()
    num_train = int(num_nodes * args.train_ratio)
    idx = [i for i in range(num_nodes)]
    np.random.shuffle(idx)
    train_idx = idx[:num_train]
    train_idx_all.append(train_idx)
    test_idx = idx[num_train:]
    test_idx_all.append(test_idx)

    data = gengraph.preprocess_input_graph(G_list[0], labels[0])
    all_labels = data["labels"]
    labels_train = torch.tensor(data["labels"][:, train_idx], dtype=torch.long)
    adj = torch.tensor(data["adj"], dtype=torch.float)
    x = torch.tensor(data["feat"], requires_grad=True, dtype=torch.float)

    for i in range(1, len(G_list)):
        np.random.shuffle(idx)
        train_idx = idx[:num_train]
        train_idx_all.append(train_idx)
        test_idx = idx[num_train:]
        test_idx_all.append(test_idx)
        data = gengraph.preprocess_input_graph(G_list[i], labels[i])
        all_labels = np.concatenate((all_labels, data["labels"]), axis=0)
        labels_train = torch.cat(
            [
                labels_train,
                torch.tensor(data["labels"][:, train_idx], dtype=torch.long),
            ],
            dim=0,
        )
        adj = torch.cat([adj, torch.tensor(data["adj"], dtype=torch.float)])
        x = torch.cat([
            x,
            torch.tensor(data["feat"], requires_grad=True, dtype=torch.float)
        ])

    scheduler, optimizer = train_utils.build_optimizer(
        args, model.parameters(), weight_decay=args.weight_decay)
    model.train()
    ypred = None
    for epoch in range(args.num_epochs):
        begin_time = time.time()
        model.zero_grad()

        if args.gpu:
            ypred = model(x.cuda(), adj.cuda())
        else:
            ypred = model(x, adj)
        # normal indexing
        ypred_train = ypred[:, train_idx, :]
        # in multigraph setting we can't directly access all dimensions so we need to gather all the training instances
        all_train_idx = [item for sublist in train_idx_all for item in sublist]
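        # NOTE: the graph count (10) and the per-graph shape (146 training
        # nodes, 6 output classes) in the next statement appear to be
        # hard-coded for the particular dataset this script was written for.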
        ypred_train_cmp = torch.cat(
            [ypred[i, train_idx_all[i], :] for i in range(10)],
            dim=0).reshape(10, 146, 6)
        if args.gpu:
            loss = model.loss(ypred_train_cmp, labels_train.cuda())
        else:
            loss = model.loss(ypred_train_cmp, labels_train)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.clip)  # clip_grad_norm (no underscore) is deprecated

        optimizer.step()
        #for param_group in optimizer.param_groups:
        #    print(param_group["lr"])
        elapsed = time.time() - begin_time

        result_train, result_test = evaluate_node(ypred.cpu(), all_labels,
                                                  train_idx_all, test_idx_all)
        if writer is not None:
            writer.add_scalar("loss/avg_loss", loss, epoch)
            writer.add_scalars(
                "prec",
                {
                    "train": result_train["prec"],
                    "test": result_test["prec"]
                },
                epoch,
            )
            writer.add_scalars(
                "recall",
                {
                    "train": result_train["recall"],
                    "test": result_test["recall"]
                },
                epoch,
            )
            writer.add_scalars("acc", {
                "train": result_train["acc"],
                "test": result_test["acc"]
            }, epoch)

        print(
            "epoch: ",
            epoch,
            "; loss: ",
            loss.item(),
            "; train_acc: ",
            result_train["acc"],
            "; test_acc: ",
            result_test["acc"],
            "; epoch time: ",
            "{0:0.2f}".format(elapsed),
        )

        if scheduler is not None:
            scheduler.step()
    print(result_train["conf_mat"])
    print(result_test["conf_mat"])

    # computation graph
    model.eval()
    if args.gpu:
        ypred = model(x.cuda(), adj.cuda())
    else:
        ypred = model(x, adj)
    cg_data = {
        "adj": adj.cpu().detach().numpy(),
        "feat": x.cpu().detach().numpy(),
        "label": all_labels,
        "pred": ypred.cpu().detach().numpy(),
        "train_idx": train_idx_all,
    }
    io_utils.save_checkpoint(model,
                             optimizer,
                             args,
                             num_epochs=-1,
                             cg_dict=cg_data)
Code example #4
File: train.py Project: abhinavkaul95/NMTree
def main(opt):
    # set random seed
    torch.manual_seed(opt.seed)
    random.seed(opt.seed)

    # initialize
    opt.dataset_split_by = opt.dataset + '_' + opt.split_by
    checkpoint_dir = os.path.join(opt.checkpoint_path,
                                  opt.dataset_split_by + '_' + opt.id)
    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    logger = initialize_logger(os.path.join(checkpoint_dir, 'train.log'))
    print = logger.info

    # set up loader
    data_json = os.path.join(opt.feats_path, opt.dataset_split_by,
                             opt.data_file + '.json')
    data_pth = os.path.join(opt.feats_path, opt.dataset_split_by,
                            opt.data_file + '.pth')
    visual_feats_dir = os.path.join(opt.feats_path, opt.dataset_split_by,
                                    opt.visual_feat_file)

    if os.path.isfile(data_pth):
        loader = GtLoader(data_json, visual_feats_dir, opt, data_pth)
        opt.tag_vocab_size = loader.tag_vocab_size
        opt.dep_vocab_size = loader.dep_vocab_size
    else:
        loader = GtLoader(data_json, visual_feats_dir, opt)

    opt.word_vocab_size = loader.word_vocab_size
    opt.vis_dim = loader.vis_dim

    # print out the option variables
    print("*" * 20)
    for k, v in opt.__dict__.items():
        print("%r: %r" % (k, v))
    print("*" * 20)

    # load previous checkpoint if possible
    infos = {}
    if opt.start_from:
    assert os.path.isdir(
        opt.start_from), " %s must be a path" % opt.start_from
        assert os.path.isfile(
            os.path.join(opt.start_from, "infos.json")
        ), "infos.json file does not exist in path %s" % opt.start_from
        print("Load infos ...")
        with open(os.path.join(opt.start_from, 'infos.json'), 'r') as f:
            infos = json.load(f)

    # resume checkpoint or from scratch
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    loader.iterators = infos.get('iterators', loader.iterators)

    # some histories may useful
    loss_history = infos.get('loss_history', {})
    lr_history = infos.get('lr_history', {})
    val_loss_history = infos.get('val_loss_history', {})
    val_accuracies = infos.get('val_accuracies', [])
    test_accuracies = infos.get('test_accuracies', [])
    test_loss_history = infos.get('test_loss_history', {})
    best_val_score = infos.get('best_val_score', None)
    best_epoch = infos.get('best_epoch', 0)

    # set up model and criterion
    model = models.setup(opt, loader).cuda()
    crit = torch.nn.NLLLoss()

    # set up optimizer
    weights, biases = [], []
    for name, p in model.named_parameters():
        if 'bias' in name:
            biases += [p]
        else:
            weights += [p]
    optimizer = train_utils.build_optimizer(weights, biases, opt)

    # check compatibility if training is continued from previously saved model
    if opt.start_from:
        # check if all necessary files exist
        assert os.path.isfile(os.path.join(opt.start_from, "model.pth"))
        print("Load model ...")
        model.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model.pth')))

    # Load the optimizer
    if opt.start_from:
        assert os.path.isfile(os.path.join(opt.start_from, "optimizer.pth"))
        print("Load optimizer ...")
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # Load the pretrained word vector
    if opt.glove and not opt.start_from:
        glove_weight = load_glove(glove=opt.glove,
                                  vocab=loader.word_to_ix,
                                  opt=opt)
        assert glove_weight.shape == model.word_embedding.weight.size()
        with torch.no_grad():
            model.word_embedding.weight.set_(
                torch.cuda.FloatTensor(glove_weight))
        print("Load word vectors ...")

    # start training
    tic = time.time()
    wrapped = False
    while True:
        model.train()

        # decay the learning rates
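        # current_lr = learning_rate * decay_rate ** floor((epoch - decay_start) / decay_every)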
        if 0 <= opt.learning_rate_decay_start < epoch:
            frac = (epoch - opt.learning_rate_decay_start
                    ) // opt.learning_rate_decay_every
            decay_factor = opt.learning_rate_decay_rate**frac
            opt.current_lr = opt.learning_rate * decay_factor
            train_utils.set_lr(
                optimizer, opt.current_lr)  # update optimizer's learning rate
        else:
            opt.current_lr = opt.learning_rate

        # start training
        optimizer.zero_grad()
        total_loss = 0.0
        n = 0.0
        acc = 0.0
        for _ in range(opt.batch_size):
            # read data
            data = loader.get_data('train')

            wrapped = True if data['bounds']['wrapped'] else wrapped
            torch.cuda.synchronize()

            # model forward
            scores = model(data)
            target = data['gts']
            loss = crit(scores, target)
            loss.backward()
            total_loss += loss.detach().cpu().numpy()

            # compute accuracy
            scores = scores.data.cpu().numpy()
            gt_ix = target.detach().cpu().numpy()
            pred_ix = np.argmax(scores, axis=1)
            n += len(gt_ix)
            acc += sum(pred_ix == gt_ix)

        total_loss /= opt.batch_size
        train_accuracy = acc / n * 100

        # clip gradients and update parameters (backward was called per sample above)
        train_utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        torch.cuda.synchronize()

        # write the training loss summary
        if iteration % opt.losses_log_every == 0:
            loss_history[iteration] = total_loss
            lr_history[iteration] = opt.current_lr
            print(
                'epoch=%d, iter=%d, train_loss=%.3f, train_acc=%.2f, time=%.2f'
                % (epoch, iteration, total_loss, train_accuracy,
                   time.time() - tic))
            tic = time.time()

        # eval loss and save checkpoint
        if wrapped and epoch % opt.save_checkpoint_every == 0:
            # evaluate models
            acc, n, val_loss = eval_utils.eval_gt_split(
                loader, model, crit, 'val', vars(opt))
            val_accuracy = acc / n * 100
            print("%s set evaluated. val_loss = %.2f, acc = %d / %d = %.2f%%" %
                  ('val', val_loss, acc, n, val_accuracy))
            val_loss_history[iteration] = val_loss
            val_accuracies.append(val_accuracy)

            if opt.split_by == 'unc':
                test_split = 'testA'
            else:
                test_split = 'test'
            test_acc, test_n, test_loss = eval_utils.eval_gt_split(
                loader, model, crit, test_split, vars(opt))
            test_accuracy = test_acc / test_n * 100
            print(
                "%s set evaluated. test_loss = %.2f, acc = %d / %d = %.2f%%" %
                (test_split, test_loss, test_acc, test_n, test_accuracy))
            test_loss_history[iteration] = test_loss
            test_accuracies.append(test_accuracy)

            # save model
            checkpoint_path = os.path.join(checkpoint_dir, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            optimizer_path = os.path.join(checkpoint_dir, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)
            print("model saved to {}".format(checkpoint_path))

            # save infos
            infos['iter'] = iteration + 1
            infos['epoch'] = epoch + 1
            infos['iterators'] = loader.iterators
            infos['opt'] = vars(opt)
            infos['word_to_ix'] = loader.word_to_ix

            # save histories
            infos['loss_history'] = loss_history
            infos['lr_history'] = lr_history
            infos['val_accuracies'] = val_accuracies
            infos['val_loss_history'] = val_loss_history
            infos['test_accuracies'] = test_accuracies
            infos['test_loss_history'] = test_loss_history
            infos['best_val_score'] = best_val_score
            infos['best_epoch'] = best_epoch

            # save model if best
            current_score = val_accuracy + test_accuracy
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_epoch = epoch
                model_save_path = os.path.join(checkpoint_dir,
                                               'model-best.pth')
                torch.save(model.state_dict(), model_save_path)
                with open(os.path.join(checkpoint_dir, 'infos-best.json'),
                          'w') as f:
                    json.dump(infos, f, sort_keys=True, indent=4)
                print("model saved to {}".format(model_save_path))
            else:
                print("The best model in epoch{}: {}".format(
                    best_epoch, best_val_score))

            with open(os.path.join(checkpoint_dir, 'infos.json'), 'w') as f:
                json.dump(infos, f, sort_keys=True, indent=4)

        # update iteration and epoch
        iteration += 1
        if wrapped:
            wrapped = False
            epoch += 1
            loader.shuffle(split='train')

        if 0 < opt.max_epochs <= epoch:
            break
Code example #5
def PolarOffsetMain(args, cfg):
    if args.launcher is None:
        dist_train = False
    else:
        args.batch_size, cfg.LOCAL_RANK = getattr(
            common_utils, 'init_dist_%s' % args.launcher)(args.batch_size,
                                                          args.tcp_port,
                                                          args.local_rank,
                                                          backend='nccl')
        dist_train = True
    cfg['DIST_TRAIN'] = dist_train
    output_dir = os.path.join('./output', args.tag)
    ckpt_dir = os.path.join(output_dir, 'ckpt')
    tmp_dir = os.path.join(output_dir, 'tmp')
    summary_dir = os.path.join(output_dir, 'summary')
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir, exist_ok=True)
    if not os.path.exists(tmp_dir):
        os.makedirs(tmp_dir, exist_ok=True)
    if not os.path.exists(summary_dir):
        os.makedirs(summary_dir, exist_ok=True)

    if args.onlyval and args.saveval:
        results_dir = os.path.join(output_dir, 'test', 'sequences')
        if not os.path.exists(results_dir):
            os.makedirs(results_dir, exist_ok=True)
        for i in range(8, 9):
            sub_dir = os.path.join(results_dir, str(i).zfill(2), 'predictions')
            if not os.path.exists(sub_dir):
                os.makedirs(sub_dir, exist_ok=True)

    if args.onlytest:
        results_dir = os.path.join(output_dir, 'test', 'sequences')
        if not os.path.exists(results_dir):
            os.makedirs(results_dir, exist_ok=True)
        for i in range(11, 22):
            sub_dir = os.path.join(results_dir, str(i).zfill(2), 'predictions')
            if not os.path.exists(sub_dir):
                os.makedirs(sub_dir, exist_ok=True)

    log_file = os.path.join(
        output_dir, ('log_train_%s.txt' %
                     datetime.datetime.now().strftime('%Y%m%d-%H%M%S')))
    logger = common_utils.create_logger(log_file, rank=cfg.LOCAL_RANK)

    logger.info('**********************Start logging**********************')
    gpu_list = os.environ[
        'CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ.keys(
        ) else 'ALL'
    logger.info('CUDA_VISIBLE_DEVICES=%s' % gpu_list)

    if dist_train:
        total_gpus = dist.get_world_size()
        logger.info('total_batch_size: %d' % (total_gpus * args.batch_size))
    for key, val in vars(args).items():
        logger.info('{:16} {}'.format(key, val))
    log_config_to_file(cfg, logger=logger)
    if cfg.LOCAL_RANK == 0:
        os.system('cp %s %s' % (args.config, output_dir))

    ### create dataloader
    if (not args.onlytest) and (not args.onlyval):
        train_dataset_loader = build_dataloader(args,
                                                cfg,
                                                split='train',
                                                logger=logger)
        val_dataset_loader = build_dataloader(args,
                                              cfg,
                                              split='val',
                                              logger=logger,
                                              no_shuffle=True,
                                              no_aug=True)
    elif args.onlyval:
        val_dataset_loader = build_dataloader(args,
                                              cfg,
                                              split='val',
                                              logger=logger,
                                              no_shuffle=True,
                                              no_aug=True)
    else:
        test_dataset_loader = build_dataloader(args,
                                               cfg,
                                               split='test',
                                               logger=logger,
                                               no_shuffle=True,
                                               no_aug=True)

    ### create model
    model = build_network(cfg)
    model.cuda()

    ### create optimizer
    optimizer = train_utils.build_optimizer(model, cfg)

    ### load ckpt
    ckpt_fname = os.path.join(ckpt_dir, args.ckpt_name)
    epoch = -1

    other_state = {}
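    # Three loading cases are handled below:
    #   1) a pretrained ckpt is given and a run checkpoint already exists ->
    #      optionally freeze parts of the model, then resume with optimizer state;
    #   2) only a pretrained ckpt is given -> load its weights, optionally freeze,
    #      and start from a fresh optimizer;
    #   3) only a run checkpoint exists -> plain resume.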
    if args.pretrained_ckpt is not None and os.path.exists(ckpt_fname):
        logger.info(
            "Now in pretrain mode and loading ckpt: {}".format(ckpt_fname))
        if not args.nofix:
            if args.fix_semantic_instance:
                logger.info(
                    "Freezing backbone, semantic and instance part of the model."
                )
                model.fix_semantic_instance_parameters()
            else:
                logger.info(
                    "Freezing semantic and backbone part of the model.")
                model.fix_semantic_parameters()
        optimizer = train_utils.build_optimizer(model, cfg)
        epoch, other_state = train_utils.load_params_with_optimizer_otherstate(
            model,
            ckpt_fname,
            to_cpu=dist_train,
            optimizer=optimizer,
            logger=logger)  # new feature
        logger.info("Loaded Epoch: {}".format(epoch))
    elif args.pretrained_ckpt is not None:
        train_utils.load_pretrained_model(model,
                                          args.pretrained_ckpt,
                                          to_cpu=dist_train,
                                          logger=logger)
        if not args.nofix:
            if args.fix_semantic_instance:
                logger.info(
                    "Freezing backbone, semantic and instance part of the model."
                )
                model.fix_semantic_instance_parameters()
            else:
                logger.info(
                    "Freezing semantic and backbone part of the model.")
                model.fix_semantic_parameters()
        else:
            logger.info("No Freeze.")
        optimizer = train_utils.build_optimizer(model, cfg)
    elif os.path.exists(ckpt_fname):
        epoch, other_state = train_utils.load_params_with_optimizer_otherstate(
            model,
            ckpt_fname,
            to_cpu=dist_train,
            optimizer=optimizer,
            logger=logger)  # new feature
        logger.info("Loaded Epoch: {}".format(epoch))
    if other_state is None:
        other_state = {}

    ### create optimizer and scheduler
    lr_scheduler = None
    if lr_scheduler is None:
        logger.info('Not using lr scheduler')

    # set train mode before wrapping in DistributedDataParallel so that any frozen parameters stay fixed
    model.train()
    if dist_train:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[cfg.LOCAL_RANK % torch.cuda.device_count()],
            find_unused_parameters=True)
    logger.info(model)

    if cfg.LOCAL_RANK == 0:
        writer = SummaryWriter(log_dir=summary_dir)

    logger.info('**********************Start Training**********************')
    rank = cfg.LOCAL_RANK
    best_before_iou = -1 if 'best_before_iou' not in other_state else other_state[
        'best_before_iou']
    best_pq = -1 if 'best_pq' not in other_state else other_state['best_pq']
    best_after_iou = -1 if 'best_after_iou' not in other_state else other_state[
        'best_after_iou']
    global_iter = 0 if 'global_iter' not in other_state else other_state[
        'global_iter']
    val_global_iter = 0 if 'val_global_iter' not in other_state else other_state[
        'val_global_iter']
    best_tracking_loss = 10086 if 'best_tracking_loss' not in other_state else other_state[
        'best_tracking_loss']

    ### test
    if args.onlytest:
        logger.info('----EPOCH {} Testing----'.format(epoch))
        model.eval()
        if rank == 0:
            vbar = tqdm(total=len(test_dataset_loader), dynamic_ncols=True)
        for i_iter, inputs in enumerate(test_dataset_loader):
            with torch.no_grad():
                if cfg.MODEL.NAME.startswith(
                        'PolarOffsetSpconvPytorchMeanshiftTracking'
                ) or cfg.MODEL.NAME.startswith('PolarOffsetSpconvTracking'):
                    ret_dict = model(inputs,
                                     is_test=True,
                                     merge_evaluator_list=None,
                                     merge_evaluator_window_k_list=None,
                                     require_cluster=True)
                else:
                    ret_dict = model(inputs,
                                     is_test=True,
                                     require_cluster=True,
                                     require_merge=True)
                common_utils.save_test_results(ret_dict, results_dir, inputs)
            if rank == 0:
                vbar.set_postfix({
                    'fname':
                    '/'.join(inputs['pcd_fname'][0].split('/')[-3:])
                })
                vbar.update(1)
        if rank == 0:
            vbar.close()
        logger.info("----Testing Finished----")
        return

    ### evaluate
    if args.onlyval:
        logger.info('----EPOCH {} Evaluating----'.format(epoch))
        model.eval()
        min_points = 50  # according to SemanticKITTI official rule
        if cfg.MODEL.NAME.startswith(
                'PolarOffsetSpconvPytorchMeanshiftTracking'
        ) or cfg.MODEL.NAME.startswith('PolarOffsetSpconvTracking'):
            merge_evaluator_list = []
            merge_evaluator_window_k_list = []
            for k in [1, 5, 10, 15]:
                merge_evaluator_list.append(init_eval(min_points))
                merge_evaluator_window_k_list.append(k)
        else:
            before_merge_evaluator = init_eval(min_points=min_points)
            after_merge_evaluator = init_eval(min_points=min_points)
        if rank == 0:
            vbar = tqdm(total=len(val_dataset_loader), dynamic_ncols=True)
        for i_iter, inputs in enumerate(val_dataset_loader):
            inputs['i_iter'] = i_iter
            # torch.cuda.empty_cache()
            with torch.no_grad():
                if cfg.MODEL.NAME.startswith(
                        'PolarOffsetSpconvPytorchMeanshiftTracking'
                ) or cfg.MODEL.NAME.startswith('PolarOffsetSpconvTracking'):
                    ret_dict = model(inputs,
                                     is_test=True,
                                     merge_evaluator_list=merge_evaluator_list,
                                     merge_evaluator_window_k_list=
                                     merge_evaluator_window_k_list,
                                     require_cluster=True)
                else:
                    ret_dict = model(
                        inputs,
                        is_test=True,
                        before_merge_evaluator=before_merge_evaluator,
                        after_merge_evaluator=after_merge_evaluator,
                        require_cluster=True)
                #########################
                # with open('./ipnb/{}_matching_list.pkl'.format(i_iter), 'wb') as fd:
                #     pickle.dump(ret_dict['matching_list'], fd)
                #########################
                if args.saveval:
                    common_utils.save_test_results(ret_dict, results_dir,
                                                   inputs)
            if rank == 0:
                vbar.set_postfix({
                    'loss':
                    ret_dict['loss'].item(),
                    'fname':
                    '/'.join(inputs['pcd_fname'][0].split('/')[-3:]),
                    'ins_num':
                    -1 if 'ins_num' not in ret_dict else ret_dict['ins_num']
                })
                vbar.update(1)
        if dist_train:
            if cfg.MODEL.NAME.startswith(
                    'PolarOffsetSpconvPytorchMeanshiftTracking'
            ) or cfg.MODEL.NAME.startswith('PolarOffsetSpconvTracking'):
                pass
            else:
                before_merge_evaluator = common_utils.merge_evaluator(
                    before_merge_evaluator, tmp_dir)
                dist.barrier()
                after_merge_evaluator = common_utils.merge_evaluator(
                    after_merge_evaluator, tmp_dir)

        if rank == 0:
            vbar.close()
        if rank == 0:
            ## print results
            if cfg.MODEL.NAME.startswith(
                    'PolarOffsetSpconvPytorchMeanshiftTracking'
            ) or cfg.MODEL.NAME.startswith('PolarOffsetSpconvTracking'):
                for evaluate, window_k in zip(merge_evaluator_list,
                                              merge_evaluator_window_k_list):
                    logger.info("Current Window K: {}".format(window_k))
                    printResults(evaluate, logger=logger)
            else:
                logger.info("Before Merge Semantic Scores")
                before_merge_results = printResults(before_merge_evaluator,
                                                    logger=logger,
                                                    sem_only=True)
                logger.info("After Merge Panoptic Scores")
                after_merge_results = printResults(after_merge_evaluator,
                                                   logger=logger)

        logger.info("----Evaluating Finished----")
        return

    ### train
    while True:
        epoch += 1
        if 'MAX_EPOCH' in cfg.OPTIMIZE.keys():
            if epoch > cfg.OPTIMIZE.MAX_EPOCH:
                break

        ### train one epoch
        logger.info('----EPOCH {} Training----'.format(epoch))
        loss_acc = 0
        if rank == 0:
            pbar = tqdm(total=len(train_dataset_loader), dynamic_ncols=True)
        for i_iter, inputs in enumerate(train_dataset_loader):
            # torch.cuda.empty_cache()
            torch.autograd.set_detect_anomaly(True)
            model.train()
            optimizer.zero_grad()
            inputs['i_iter'] = i_iter
            inputs['rank'] = rank
            ret_dict = model(inputs)

            if args.pretrained_ckpt is not None and not args.fix_semantic_instance:  # training offset
                if args.nofix:
                    loss = ret_dict['loss']
                elif len(ret_dict['offset_loss_list']) > 0:
                    loss = sum(ret_dict['offset_loss_list'])
                else:
                    loss = torch.tensor(0.0, requires_grad=True)  #mock pbar
                    ret_dict['offset_loss_list'] = [loss]  #mock writer
            elif args.pretrained_ckpt is not None and args.fix_semantic_instance and cfg.MODEL.NAME == 'PolarOffsetSpconvPytorchMeanshift':  # training dynamic shifting
                loss = sum(ret_dict['meanshift_loss'])
            elif cfg.MODEL.NAME.startswith(
                    'PolarOffsetSpconvPytorchMeanshiftTracking'
            ) or cfg.MODEL.NAME.startswith('PolarOffsetSpconvTracking'):
                loss = sum(ret_dict['tracking_loss'])
                #########################
                # with open('./ipnb/{}_matching_list.pkl'.format(i_iter), 'wb') as fd:
                #     pickle.dump(ret_dict['matching_list'], fd)
                #########################
            else:
                loss = ret_dict['loss']
            loss.backward()
            optimizer.step()

            if rank == 0:
                try:
                    cur_lr = float(optimizer.lr)
                except:
                    cur_lr = optimizer.param_groups[0]['lr']
                loss_acc += loss.item()
                pbar.set_postfix({
                    'loss': loss.item(),
                    'lr': cur_lr,
                    'mean_loss': loss_acc / float(i_iter + 1)
                })
                pbar.update(1)
                writer.add_scalar('Train/01_Loss', ret_dict['loss'].item(),
                                  global_iter)
                writer.add_scalar('Train/02_SemLoss',
                                  ret_dict['sem_loss'].item(), global_iter)
                if 'offset_loss_list' in ret_dict and sum(
                        ret_dict['offset_loss_list']).item() > 0:
                    writer.add_scalar('Train/03_InsLoss',
                                      sum(ret_dict['offset_loss_list']).item(),
                                      global_iter)
                writer.add_scalar('Train/04_LR', cur_lr, global_iter)
                writer_acc = 5
                if 'meanshift_loss' in ret_dict:
                    writer.add_scalar('Train/05_DSLoss',
                                      sum(ret_dict['meanshift_loss']).item(),
                                      global_iter)
                    writer_acc += 1
                if 'tracking_loss' in ret_dict:
                    writer.add_scalar('Train/06_TRLoss',
                                      sum(ret_dict['tracking_loss']).item(),
                                      global_iter)
                    writer_acc += 1
                more_keys = []
                for k, _ in ret_dict.items():
                    if k.find('summary') != -1:
                        more_keys.append(k)
                for ki, k in enumerate(more_keys):
                    if k == 'bandwidth_weight_summary':
                        continue
                    ki += writer_acc
                    writer.add_scalar(
                        'Train/{}_{}'.format(str(ki).zfill(2), k), ret_dict[k],
                        global_iter)
                global_iter += 1
        if rank == 0:
            pbar.close()

        ### evaluate after each epoch
        logger.info('----EPOCH {} Evaluating----'.format(epoch))
        model.eval()
        min_points = 50
        before_merge_evaluator = init_eval(min_points=min_points)
        after_merge_evaluator = init_eval(min_points=min_points)
        tracking_loss = 0
        if rank == 0:
            vbar = tqdm(total=len(val_dataset_loader), dynamic_ncols=True)
        for i_iter, inputs in enumerate(val_dataset_loader):
            # torch.cuda.empty_cache()
            inputs['i_iter'] = i_iter
            inputs['rank'] = rank
            with torch.no_grad():
                if cfg.MODEL.NAME.startswith(
                        'PolarOffsetSpconvPytorchMeanshiftTracking'
                ) or cfg.MODEL.NAME.startswith('PolarOffsetSpconvTracking'):
                    ret_dict = model(inputs,
                                     is_test=True,
                                     merge_evaluator_list=None,
                                     merge_evaluator_window_k_list=None,
                                     require_cluster=True)
                else:
                    ret_dict = model(
                        inputs,
                        is_test=True,
                        before_merge_evaluator=before_merge_evaluator,
                        after_merge_evaluator=after_merge_evaluator,
                        require_cluster=True)
            if rank == 0:
                vbar.set_postfix({'loss': ret_dict['loss'].item()})
                vbar.update(1)
                writer.add_scalar('Val/01_Loss', ret_dict['loss'].item(),
                                  val_global_iter)
                writer.add_scalar('Val/02_SemLoss',
                                  ret_dict['sem_loss'].item(), val_global_iter)
                if 'offset_loss_list' in ret_dict and sum(
                        ret_dict['offset_loss_list']).item() > 0:
                    writer.add_scalar('Val/03_InsLoss',
                                      sum(ret_dict['offset_loss_list']).item(),
                                      val_global_iter)
                if 'tracking_loss' in ret_dict:
                    writer.add_scalar('Val/06_TRLoss',
                                      sum(ret_dict['tracking_loss']).item(),
                                      global_iter)
                    tracking_loss += sum(ret_dict['tracking_loss']).item()
                more_keys = []
                for k, _ in ret_dict.items():
                    if k.find('summary') != -1:
                        more_keys.append(k)
                for ki, k in enumerate(more_keys):
                    if k == 'bandwidth_weight_summary':
                        continue
                    ki += 4
                    writer.add_scalar('Val/{}_{}'.format(str(ki).zfill(2), k),
                                      ret_dict[k], val_global_iter)
                val_global_iter += 1
        tracking_loss /= len(val_dataset_loader)
        if dist_train:
            try:
                before_merge_evaluator = common_utils.merge_evaluator(
                    before_merge_evaluator, tmp_dir, prefix='before_')
                dist.barrier()
                after_merge_evaluator = common_utils.merge_evaluator(
                    after_merge_evaluator, tmp_dir, prefix='after_')
            except:
                print("Someting went wrong when merging evaluator in rank {}".
                      format(rank))
        if rank == 0:
            vbar.close()
        if rank == 0:
            ## print results
            logger.info("Before Merge Semantic Scores")
            before_merge_results = printResults(before_merge_evaluator,
                                                logger=logger,
                                                sem_only=True)
            logger.info("After Merge Panoptic Scores")
            after_merge_results = printResults(after_merge_evaluator,
                                               logger=logger)
            ## save ckpt
            other_state = {
                'best_before_iou': best_before_iou,
                'best_pq': best_pq,
                'best_after_iou': best_after_iou,
                'global_iter': global_iter,
                'val_global_iter': val_global_iter,
                'best_tracking_loss': best_tracking_loss,
            }
            saved_flag = False
            # only save a tracking checkpoint when the tracking loss improves
            # and the model is one of the tracking variants
            if best_tracking_loss > tracking_loss and (
                    cfg.MODEL.NAME.startswith(
                        'PolarOffsetSpconvPytorchMeanshiftTracking')
                    or cfg.MODEL.NAME.startswith('PolarOffsetSpconvTracking')):
                best_tracking_loss = tracking_loss
                if not saved_flag:
                    states = train_utils.checkpoint_state(
                        model, optimizer, epoch, other_state)
                    train_utils.save_checkpoint(
                        states,
                        os.path.join(
                            ckpt_dir, 'checkpoint_epoch_{}_{}.pth'.format(
                                epoch,
                                str(tracking_loss)[:5])))
                    saved_flag = True
            if best_before_iou < before_merge_results['iou_mean']:
                best_before_iou = before_merge_results['iou_mean']
                if not saved_flag:
                    states = train_utils.checkpoint_state(
                        model, optimizer, epoch, other_state)
                    train_utils.save_checkpoint(
                        states,
                        os.path.join(
                            ckpt_dir,
                            'checkpoint_epoch_{}_{}_{}_{}.pth'.format(
                                epoch,
                                str(best_before_iou)[:5],
                                str(best_pq)[:5],
                                str(best_after_iou)[:5])))
                    saved_flag = True
            if best_pq < after_merge_results['pq_mean']:
                best_pq = after_merge_results['pq_mean']
                if not saved_flag:
                    states = train_utils.checkpoint_state(
                        model, optimizer, epoch, other_state)
                    train_utils.save_checkpoint(
                        states,
                        os.path.join(
                            ckpt_dir,
                            'checkpoint_epoch_{}_{}_{}_{}.pth'.format(
                                epoch,
                                str(best_before_iou)[:5],
                                str(best_pq)[:5],
                                str(best_after_iou)[:5])))
                    saved_flag = True
            if best_after_iou < after_merge_results['iou_mean']:
                best_after_iou = after_merge_results['iou_mean']
                if not saved_flag:
                    states = train_utils.checkpoint_state(
                        model, optimizer, epoch, other_state)
                    train_utils.save_checkpoint(
                        states,
                        os.path.join(
                            ckpt_dir,
                            'checkpoint_epoch_{}_{}_{}_{}.pth'.format(
                                epoch,
                                str(best_before_iou)[:5],
                                str(best_pq)[:5],
                                str(best_after_iou)[:5])))
                    saved_flag = True
            logger.info("Current best before IoU: {}".format(best_before_iou))
            logger.info("Current best after IoU: {}".format(best_after_iou))
            logger.info("Current best after PQ: {}".format(best_pq))
            logger.info(
                "Current best tracking loss: {}".format(best_tracking_loss))
        if lr_scheduler is not None:
            lr_scheduler.step(epoch)  # new feature
Code example #6
def syn_task1(args, writer=None):
    print('Generating graph.')
    feature_generator = featgen.ConstFeatureGen(
        np.ones(args.input_dim, dtype=float))
    if args.dataset == 'syn1':
        gen_fn = gengraph.gen_syn1
    elif args.dataset == 'syn2':
        gen_fn = gengraph.gen_syn2
        feature_generator = None
    elif args.dataset == 'syn3':
        gen_fn = gengraph.gen_syn3
    elif args.dataset == 'syn4':
        gen_fn = gengraph.gen_syn4
    elif args.dataset == 'syn5':
        gen_fn = gengraph.gen_syn5
    else:
        raise ValueError('unknown dataset: {}'.format(args.dataset))
    G, labels, name = gen_fn(feature_generator=feature_generator)
    # torch devices are named 'cuda'/'cpu' ('gpu' is not a valid device string)
    pyg_G = NxDataset([G],
                      device=torch.device('cuda' if args.gpu else 'cpu'))[0]
    num_classes = max(labels) + 1
    labels = torch.LongTensor(labels)
    print('Done generating graph.')

    model = GCNNet(args.input_dim,
                   args.hidden_dim,
                   args.output_dim,
                   num_classes,
                   args.num_gc_layers,
                   args=args)

    if args.gpu:
        model = model.cuda()

    train_ratio = args.train_ratio
    num_train = int(train_ratio * G.number_of_nodes())
    num_test = G.number_of_nodes() - num_train

    idx = [i for i in range(G.number_of_nodes())]

    np.random.shuffle(idx)
    train_mask = idx[:num_train]
    test_mask = idx[num_train:]

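    # full-batch training: the single synthetic graph is wrapped in a
    # one-element DataLoader with batch_size=1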
    loader = torch_geometric.data.DataLoader([pyg_G], batch_size=1)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)  # superseded by build_optimizer below
    scheduler, opt = train_utils.build_optimizer(
        args, model.parameters(), weight_decay=args.weight_decay)
    for epoch in range(args.num_epochs):
        model.train()
        total_loss = 0
        for batch in loader:
            opt.zero_grad()
            pred = model(batch)

            pred = pred[train_mask]
            label = labels[train_mask]
            loss = model.loss(pred, label)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            opt.step()
            total_loss += loss.item()
        writer.add_scalar("loss", total_loss, epoch)

        if epoch % 10 == 0:
            test_acc = test(loader, model, args, labels, test_mask)
            print("{} {:.4f} {:.4f}".format(epoch, total_loss, test_acc))
            writer.add_scalar("test", test_acc, epoch)

    print("{} {:.4f} {:.4f}".format(epoch, total_loss, test_acc))
    data = gengraph.preprocess_input_graph(G, labels)
    adj = torch.tensor(data['adj'], dtype=torch.float)
    x = torch.tensor(data['feat'], requires_grad=True, dtype=torch.float)

    model.eval()
    ypred = model(batch)
Code example #7
def main(_):
    logging.set_verbosity(tf.logging.INFO)

    assert os.path.isfile(FLAGS.pipeline_proto)

    g = tf.Graph()
    with g.as_default():
        pipeline_proto = load_pipeline_proto(FLAGS.pipeline_proto)
        logging.info("Pipeline configure: %s", '=' * 128)
        logging.info(pipeline_proto)

        train_config = pipeline_proto.train_config

        # Get examples from reader.
        examples, feed_init_fn = ads_mem_examples.get_examples(
            pipeline_proto.example_reader, split='train')

        # Build model for training.
        model = builder.build(pipeline_proto.model, is_training=True)
        predictions = model.build_inference_graph(examples)
        loss_dict = model.build_loss(predictions)

        model_init_fn = model.get_init_fn()
        uninitialized_variable_names = tf.report_uninitialized_variables()

        if FLAGS.restore_from:
            variables_to_restore = slim.get_variables_to_restore(
                exclude=[name for name in train_config.exclude_variable])
            restore_init_fn = slim.assign_from_checkpoint_fn(
                FLAGS.restore_from, variables_to_restore)

        def init_fn(sess):
            model_init_fn(sess)
            if FLAGS.restore_from:
                restore_init_fn(sess)

        # Loss and optimizer.
        for loss_name, loss_tensor in loss_dict.items():  # .iteritems() is Python 2 only
            tf.losses.add_loss(loss_tensor)
            tf.summary.scalar('losses/{}'.format(loss_name), loss_tensor)
        total_loss = tf.losses.get_total_loss()
        tf.summary.scalar('losses/total_loss', total_loss)

        for reg_loss in tf.losses.get_regularization_losses():
            name = 'losses/reg_loss_{}'.format(reg_loss.op.name.split('/')[0])
            tf.summary.scalar(name, reg_loss)

        optimizer = train_utils.build_optimizer(train_config)
        if train_config.moving_average:
            optimizer = tf.contrib.opt.MovingAverageOptimizer(
                optimizer, average_decay=0.99)

        gradient_multipliers = train_utils.build_multipler(
            train_config.gradient_multiplier)

        variables_to_train = model.get_variables_to_train()
        logging.info('=' * 128)
        for var in variables_to_train:
            logging.info(var)
        train_op = slim.learning.create_train_op(
            total_loss,
            variables_to_train=variables_to_train,
            clip_gradient_norm=0.0,
            gradient_multipliers=gradient_multipliers,
            summarize_gradients=True,
            optimizer=optimizer)

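        # MovingAverageOptimizer.swapping_saver() writes the exponential
        # moving-average weights into checkpoints in place of the raw variables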
        saver = None
        if train_config.moving_average:
            saver = optimizer.swapping_saver()

    # Start checking.
    logging.info('Start checking...')
    session_config = train_utils.default_session_config(
        FLAGS.per_process_gpu_memory_fraction)

    def _session_wrapper_fn(sess):
        feed_init_fn(sess)
        return sess

    slim.learning.train(train_op,
                        logdir=FLAGS.train_log_dir,
                        graph=g,
                        master='',
                        is_chief=True,
                        number_of_steps=train_config.number_of_steps,
                        log_every_n_steps=train_config.log_every_n_steps,
                        save_interval_secs=train_config.save_interval_secs,
                        save_summaries_secs=train_config.save_summaries_secs,
                        session_config=session_config,
                        session_wrapper=_session_wrapper_fn,
                        init_fn=init_fn,
                        saver=saver)

    logging.info('Done')
Code example #8
def medic(args):
    """
    Create a simple Graph ConvNet using the parameters in args (https://arxiv.org/abs/1609.02907).
    """

    # Loading DataSet from /Pickles
    global result_test, result_train
    with open('Pickles/feats.pickle', 'rb') as handle:
        feats = np.expand_dims(pickle.load(handle), axis=0)
    with open('Pickles/age_adj.pickle', 'rb') as handle:
        age_adj = pickle.load(handle)
    with open('Pickles/preds.pickle', 'rb') as handle:
        labels = np.expand_dims(pickle.load(handle), axis=0)

    # initializing model variables
    num_nodes = labels.shape[1]
    num_train = int(num_nodes * 0.9)
    num_classes = max(labels[0]) + 1
    idx = [i for i in range(num_nodes)]
    np.random.shuffle(idx)
    train_idx = idx[:num_train]
    test_idx = idx[num_train:]

    labels = labels.astype(np.int64)    # np.long / np.float are deprecated NumPy aliases
    age_adj = age_adj.astype(np.float64)
    feats = feats.astype(np.float64)

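    # Symmetric GCN normalization (Kipf & Welling): A_hat = D^{-1/2} (A + I) D^{-1/2},
    # where D is the degree matrix of A + I.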
    age_adj = age_adj + np.eye(age_adj.shape[0])
    d_hat_inv = np.linalg.inv(np.diag(age_adj.sum(axis=1)))**(1 / 2)
    temp = np.matmul(d_hat_inv, age_adj)
    age_adj = np.matmul(temp, d_hat_inv)
    age_adj = np.expand_dims(age_adj, axis=0)

    labels_train = torch.tensor(labels[:, train_idx], dtype=torch.long)
    adj = torch.tensor(age_adj, dtype=torch.float)
    x = torch.tensor(feats, dtype=torch.float, requires_grad=True)

    # Creating a model which is used in https://github.com/RexYing/gnn-model-explainer
    model = models.GcnEncoderNode(
        args.input_dim,
        args.hidden_dim,
        args.output_dim,
        num_classes,
        args.num_gc_layers,
        bn=args.bn,
        args=args,
    )

    if args.gpu:
        model = model.cuda()

    scheduler, optimizer = build_optimizer(args,
                                           model.parameters(),
                                           weight_decay=args.weight_decay)
    model.train()
    to_save = (0, None)  # used for saving best model

    # training the model
    for epoch in range(args.num_epochs):
        begin_time = time.time()
        model.zero_grad()

        if args.gpu:
            ypred, adj_att = model(x.cuda(), adj.cuda())
        else:
            ypred, adj_att = model(x, adj)
        ypred_train = ypred[:, train_idx, :]
        if args.gpu:
            loss = model.loss(ypred_train, labels_train.cuda())
        else:
            loss = model.loss(ypred_train, labels_train)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.clip)  # clip_grad_norm (no underscore) is deprecated

        optimizer.step()
        # for param_group in optimizer.param_groups:
        #    print(param_group["lr"])
        elapsed = time.time() - begin_time

        result_train, result_test = evaluate_node(ypred.cpu(), labels,
                                                  train_idx, test_idx)

        if result_test["acc"] > to_save[0]:
            to_save = (result_test["acc"], (model, optimizer, args))

        if epoch % 10 == 0:
            print(
                "epoch: ",
                epoch,
                "; loss: ",
                loss.item(),
                "; train_acc: ",
                result_train["acc"],
                "; test_acc: ",
                result_test["acc"],
                "; train_prec: ",
                result_train["prec"],
                "; test_prec: ",
                result_test["prec"],
                "; epoch time: ",
                "{0:0.2f}".format(elapsed),
            )
        if epoch % 100 == 0:
            print(result_train["conf_mat"])
            print(result_test["conf_mat"])

        if scheduler is not None:
            scheduler.step()

    print(result_train["conf_mat"])
    print(result_test["conf_mat"])

    to_save[1][0].eval()
    if args.gpu:
        ypred, _ = to_save[1][0](x.cuda(), adj.cuda())
    else:
        ypred, _ = to_save[1][0](x, adj)
    cg_data = {
        "adj": age_adj,
        "feat": feats,
        "label": labels,
        "pred": ypred.cpu().detach().numpy(),
        "train_idx": train_idx,
    }

    # saving the model so that it can be restored for GNN explaining
    print(
        save_checkpoint(to_save[1][0],
                        to_save[1][1],
                        args,
                        num_epochs=-1,
                        cg_dict=cg_data))

    return to_save[1][0], to_save[1][1], args, cg_data
Code example #9
def main(_):
  logging.set_verbosity(tf.logging.INFO)

  assert os.path.isfile(FLAGS.pipeline_proto), FLAGS.pipeline_proto

  g = tf.Graph()
  with g.as_default():
    pipeline_proto = load_pipeline_proto(FLAGS.pipeline_proto)
    logging.info("Pipeline configure: %s", '=' * 128)
    logging.info(pipeline_proto)

    train_config = pipeline_proto.train_config

    # Get examples from reader.
    examples = ads_examples.get_examples(pipeline_proto.example_reader)

    # Build model for training.
    global_step = slim.get_or_create_global_step()

    model = builder.build(pipeline_proto.model, is_training=True)
    predictions = model.build_inference(examples)
    loss_dict = model.build_loss(predictions)

    init_fn = model.get_init_fn()
    uninitialized_variable_names = tf.report_uninitialized_variables()

    # Loss and optimizer.
    for loss_name, loss_tensor in loss_dict.items():  # .iteritems() is Python 2 only
      tf.losses.add_loss(loss_tensor)
      tf.summary.scalar('losses/{}'.format(loss_name), loss_tensor)
    total_loss = tf.losses.get_total_loss()
    tf.summary.scalar('losses/total_loss', total_loss)

    optimizer = train_utils.build_optimizer(train_config)
    if train_config.moving_average:
      optimizer = tf.contrib.opt.MovingAverageOptimizer(optimizer,
          average_decay=0.99)

    gradient_multipliers = train_utils.build_multipler(
        train_config.gradient_multiplier)

    variables_to_train = model.get_variables_to_train()
    for var in variables_to_train:
      logging.info(var)

    train_op = slim.learning.create_train_op(total_loss,
        variables_to_train=variables_to_train, 
        clip_gradient_norm=5.0,
        gradient_multipliers=gradient_multipliers,
        summarize_gradients=True,
        optimizer=optimizer)

    saver = None
    if train_config.moving_average:
      saver = optimizer.swapping_saver()

  # Starts training.
  logging.info('Start training.')

  session_config = train_utils.default_session_config( 
      FLAGS.per_process_gpu_memory_fraction)
  slim.learning.train(train_op, 
      logdir=FLAGS.train_log_dir,
      graph=g,
      master='',
      is_chief=True,
      number_of_steps=train_config.number_of_steps,
      log_every_n_steps=train_config.log_every_n_steps,
      save_interval_secs=train_config.save_interval_secs,
      save_summaries_secs=train_config.save_summaries_secs,
      session_config=session_config,
      init_fn=init_fn,
      saver=saver)

  logging.info('Done')