Code Example #1
File: main.py  Project: wangx404/Center_Loss_in_MXNet
def train():
    """
    train model using softmax loss or softmax loss/center loss.
    训练模型。
    """
    print("Start to train...")
    if not os.path.exists(args.ckpt_dir):
        os.makedirs(args.ckpt_dir)

    train_iter, test_iter = data_loader(args.batch_size)
    ctx = mx.gpu() if args.use_gpu else mx.cpu()

    # main model (LeNetPlus), loss, trainer
    model = LeNetPlus(classes=args.num_classes, feature_size=args.feature_size)
    model.hybridize()
    model.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
    network_trainer = gluon.Trainer(model.collect_params(),
                                    optimizer="sgd",
                                    optimizer_params={
                                        "learning_rate": args.lr,
                                        "wd": args.wd
                                    })  # optionally add "momentum": 0.9
    # center loss network and trainer
    if args.center_loss:
        center_loss = CenterLoss(num_classes=args.num_classes,
                                 feature_size=args.feature_size,
                                 lmbd=args.lmbd,
                                 ctx=ctx)
        center_loss.initialize(mx.init.Xavier(magnitude=2.24),
                               ctx=ctx)  # the block holds a matrix of class centers, so it must be initialized
        center_trainer = gluon.Trainer(
            center_loss.collect_params(),
            optimizer="sgd",
            optimizer_params={"learning_rate": args.alpha})
    else:
        center_loss, center_trainer = None, None

    smoothing_constant, moving_loss = .01, 0.0
    best_acc = 0.0
    for epoch in range(args.epochs):
        # apply step decay to the learning rate during training
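        # e.g. with (hypothetical) lr_step=20 and lr_factor=0.1, the rate drops
        # to lr*0.1 at epoch 20, lr*0.01 at epoch 40, and so on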
        if (epoch > 0) and (epoch % args.lr_step == 0):
            network_trainer.set_learning_rate(network_trainer.learning_rate *
                                              args.lr_factor)
            if args.center_loss:
                center_trainer.set_learning_rate(center_trainer.learning_rate *
                                                 args.lr_factor)

        start_time = time.time()
        for i, (data, label) in enumerate(train_iter):
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)
            with autograd.record():
                output, features = model(data)
                loss_softmax = softmax_cross_entropy(output, label)
                # compute the loss according to the user's choice
                if args.center_loss:
                    loss_center = center_loss(features, label)
                    loss = loss_softmax + loss_center
                else:
                    loss = loss_softmax

            # backpropagate and update the parameters
            loss.backward()
            network_trainer.step(args.batch_size)
            if args.center_loss:
                center_trainer.step(args.batch_size)

            # compute a smoothed loss value
            curr_loss = nd.mean(loss).asscalar()
            moving_loss = (curr_loss if ((i == 0) and (epoch == 0)) else
                           (1 - smoothing_constant) * moving_loss +
                           smoothing_constant * curr_loss)  # exponentially weighted average
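            # with smoothing_constant c = 0.01 this behaves like an average over
            # roughly the last 1/c = 100 mini-batches, so the reported loss is
            # stable against batch-to-batch noise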

        # time spent on this training epoch
        elapsed_time = time.time() - start_time
        train_accuracy, train_ft, _, train_lb = evaluate_accuracy(
            train_iter, model, center_loss, args.eval_method, ctx)
        test_accuracy, test_ft, _, test_lb = evaluate_accuracy(
            test_iter, model, center_loss, args.eval_method, ctx)

        # plot the feature embeddings
        if args.plotting:
            plot_features(train_ft,
                          train_lb,
                          num_classes=args.num_classes,
                          fpath=os.path.join(
                              args.out_dir,
                              "%s-train-epoch-%d.png" % (args.prefix, epoch)))
            plot_features(test_ft,
                          test_lb,
                          num_classes=args.num_classes,
                          fpath=os.path.join(
                              args.out_dir,
                              "%s-test-epoch-%d.png" % (args.prefix, epoch)))

        logging.warning(
            "Epoch [%d]: Loss=%f, Train-Acc=%f, Test-Acc=%f, Epoch-time=%f" %
            (epoch, moving_loss, train_accuracy, test_accuracy, elapsed_time))

        # save the model parameters with the highest test accuracy so far
        if test_accuracy > best_acc:
            best_acc = test_accuracy
            model.save_parameters(
                os.path.join(args.ckpt_dir, args.prefix + "-best.params"))
            # Because CenterLoss inherits from gluon.HybridBlock, it exposes the
            # usual Block methods, so save_parameters/load_parameters can be used
            # to save and load its parameters. Without such a parent class, the
            # center matrix would have to be saved and restored by hand via
            # CenterLoss.embedding.weight.data()/set_data().
            if args.center_loss:
                center_loss.save_parameters(
                    os.path.join(args.ckpt_dir,
                                 args.prefix + "-feature_matrix.params"))
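Neither example shows the definition of CenterLoss. Below is a minimal sketch of what such a block could look like, assuming the constructor signature used above (num_classes, feature_size, lmbd, and an optional ctx) and the standard center-loss formulation: the squared distance between each feature and its class center, weighted by lmbd. The class internals, including the embedding attribute, are reconstructed from the call sites and comments above, not taken from the project.

from mxnet import gluon, nd

class CenterLoss(gluon.HybridBlock):
    """Sketch of a center-loss block; illustrative, not the project's code."""

    def __init__(self, num_classes, feature_size, lmbd, ctx=None, **kwargs):
        # ctx is accepted only for call-site compatibility with Example #1
        super(CenterLoss, self).__init__(**kwargs)
        self._lmbd = lmbd
        with self.name_scope():
            # one trainable center per class; updated by its own gluon.Trainer
            self.embedding = gluon.nn.Embedding(num_classes, feature_size)

    def hybrid_forward(self, F, features, labels):
        centers = self.embedding(labels)  # look up each sample's class center
        # 0.5 * lmbd * squared distance to the class center, per sample
        return self._lmbd * 0.5 * F.sum(F.square(features - centers), axis=1)

# Manual save/load alternative mentioned in the comments above (hypothetical):
#   nd.save(path, [center_loss.embedding.weight.data()])
#   center_loss.embedding.weight.set_data(nd.load(path)[0])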
Code Example #2
def train():
    print('Start to train...')
    if not os.path.exists(args.ckpt_dir):
        os.makedirs(args.ckpt_dir)
    # wrap the CPU context in a list too, since ctx[0] is used below
    ctx = [mx.gpu(int(i))
           for i in args.gpus.split(',')] if args.gpus != '-1' else [mx.cpu()]
    print('Loading the data...')

    train_iter, test_iter = data_loader(args.batch_size)

    model = LeNetPlus()
    model.hybridize()
    model.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)

    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

    trainer = gluon.Trainer(model.collect_params(),
                            optimizer='sgd',
                            optimizer_params={
                                'learning_rate': args.lr,
                                'wd': args.wd
                            })

    if args.center_loss:
        center_loss = CenterLoss(args.num_classes,
                                 feature_size=2,
                                 lmbd=args.lmbd)
        center_loss.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
        trainer_center = gluon.Trainer(
            center_loss.collect_params(),
            optimizer='sgd',
            optimizer_params={'learning_rate': args.alpha})
    else:
        center_loss, trainer_center = None, None

    smoothing_constant, moving_loss = .01, 0.0

    best_acc = 0
    for e in range(args.epochs):
        start_time = timeit.default_timer()

        for i, (data, label) in enumerate(train_iter):
            data = data.as_in_context(ctx[0])
            label = label.as_in_context(ctx[0])
            with autograd.record():
                output, features = model(data)
                loss_softmax = softmax_cross_entropy(output, label)
                if args.center_loss:
                    loss_center = center_loss(features, label)
                    loss = loss_softmax + loss_center
                else:
                    loss = loss_softmax
            loss.backward()
            trainer.step(data.shape[0])
            if args.center_loss:
                trainer_center.step(data.shape[0])
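            # step(data.shape[0]) normalizes gradients by the actual batch size,
            # which (unlike step(args.batch_size) in Example #1) also handles a
            # smaller final batch correctly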

            curr_loss = nd.mean(loss).asscalar()
            moving_loss = (curr_loss if ((i == 0) and (e == 0)) else
                           (1 - smoothing_constant) * moving_loss +
                           smoothing_constant * curr_loss)

        elapsed_time = timeit.default_timer() - start_time

        train_accuracy, train_ft, _, train_lb = evaluate_accuracy(
            train_iter, model, ctx)
        test_accuracy, test_ft, _, test_lb = evaluate_accuracy(
            test_iter, model, ctx)

        if args.plotting:
            plot_features(train_ft,
                          train_lb,
                          num_classes=args.num_classes,
                          fpath=os.path.join(
                              args.out_dir,
                              '%s-train-epoch-%s.png' % (args.prefix, e)))
            plot_features(test_ft,
                          test_lb,
                          num_classes=args.num_classes,
                          fpath=os.path.join(
                              args.out_dir,
                              '%s-test-epoch-%s.png' % (args.prefix, e)))

        logging.warning("Epoch [%d]: Loss=%f" % (e, moving_loss))
        logging.warning("Epoch [%d]: Train-Acc=%f" % (e, train_accuracy))
        logging.warning("Epoch [%d]: Test-Acc=%f" % (e, test_accuracy))
        logging.warning("Epoch [%d]: Elapsed-time=%f" % (e, elapsed_time))

        if test_accuracy > best_acc:
            best_acc = test_accuracy
            # save_parameters replaces the deprecated save_params
            model.save_parameters(
                os.path.join(args.ckpt_dir, args.prefix + '-best.params'))
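Both examples read their hyperparameters from an args namespace that is defined elsewhere. A hypothetical argparse setup covering every flag used above might look like the following; all default values are illustrative, not the projects' actual defaults.

import argparse

parser = argparse.ArgumentParser(description='Center-loss training (illustrative defaults)')
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--lr_step', type=int, default=20)       # Example #1 only
parser.add_argument('--lr_factor', type=float, default=0.1)  # Example #1 only
parser.add_argument('--wd', type=float, default=5e-4)
parser.add_argument('--alpha', type=float, default=0.5)      # learning rate for the centers
parser.add_argument('--lmbd', type=float, default=0.01)      # weight of the center-loss term
parser.add_argument('--num_classes', type=int, default=10)
parser.add_argument('--feature_size', type=int, default=2)   # Example #1 only
parser.add_argument('--center_loss', action='store_true')
parser.add_argument('--plotting', action='store_true')
parser.add_argument('--use_gpu', action='store_true')        # Example #1; Example #2 uses --gpus
parser.add_argument('--gpus', type=str, default='-1')
parser.add_argument('--eval_method', type=str, default='softmax')  # Example #1 only
parser.add_argument('--ckpt_dir', type=str, default='ckpt')
parser.add_argument('--out_dir', type=str, default='output')
parser.add_argument('--prefix', type=str, default='center-loss')
args = parser.parse_args()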