Example #1
def test_add_text():
    # this will generate an event file under _LOGDIR and
    # a json file called tensors.json under _LOGDIR/plugins/tensorboard_text/tensors.json
    sw = SummaryWriter(logdir=_LOGDIR)
    sw.add_text(tag='test_add_text', text='Hello MXNet!')
    sw.close()
    check_and_remove_logdir_for_text()
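The helper check_and_remove_logdir_for_text is not shown in this listing; a minimal sketch, assuming it only verifies the two files described in the comment above and then deletes the log directory:

def check_and_remove_logdir_for_text():
    # hypothetical helper: check that the event file and the text plugin's
    # tensors.json were written, then clean up the log directory
    import os
    import shutil
    json_path = os.path.join(_LOGDIR, 'plugins', 'tensorboard_text', 'tensors.json')
    assert os.path.isfile(json_path), 'missing %s' % json_path
    assert any(f.startswith('events.') for f in os.listdir(_LOGDIR)), 'missing event file'
    shutil.rmtree(_LOGDIR)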
Example #2
def test_add_scalar():
    sw = SummaryWriter(logdir=_LOGDIR)
    sw.add_scalar(tag='test_add_scalar',
                  value=np.random.uniform(),
                  global_step=0)
    sw.close()
    check_event_file_and_remove_logdir()
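check_event_file_and_remove_logdir is likewise assumed by these tests; a minimal sketch that only asserts an event file exists under _LOGDIR and then removes the directory:

def check_event_file_and_remove_logdir():
    # hypothetical helper: assert an event file was written, then clean up
    import os
    import shutil
    assert any(f.startswith('events.') for f in os.listdir(_LOGDIR)), 'missing event file'
    shutil.rmtree(_LOGDIR)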
Example #3
def test(ctx=mx.cpu()):
    from mxboard import SummaryWriter
    sw = SummaryWriter(logdir='sphere_dynamic', flush_secs=5)

    net = nn.Sequential()
    b1 = base_net(48,
                  3,
                  fun=special_conv,
                  kernel_size=(3, 3),
                  same_shape=False)
    b2 = base_net(1,
                  48,
                  fun=special_conv,
                  kernel_size=(3, 3),
                  same_shape=False)
    fc = nn.Dense(3, in_units=9)
    net.add(b1, b2, fc)
    init_s(net, ctx)

    from mxnet import gluon, autograd, nd
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': 0.01})
    for i in range(10000):
        with autograd.record():
            out = net(img)
            loss = nd.sum(nd.abs(out - target))
        loss.backward()
        trainer.step(2)
        sw.add_scalar(tag='loss', value=loss.asscalar(), global_step=i)
        if i % 100 == 0:
            print(i, loss.asscalar())
    sw.close()
Example #4
def test_add_audio():
    shape = (100,)
    data = mx.nd.random.uniform(-1, 1, shape=shape)
    sw = SummaryWriter(logdir=_LOGDIR)
    sw.add_audio(tag='test_add_audio', audio=data)
    sw.close()
    check_event_file_and_remove_logdir()
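add_audio expects waveform samples in [-1, 1]. A sketch logging a recognizable 440 Hz tone instead of noise; the sample_rate argument is an assumption about the mxboard signature, so drop it if your version does not accept it:

import numpy as np
import mxnet as mx

sample_rate = 44100
t = np.arange(sample_rate) / float(sample_rate)          # one second of samples
tone = mx.nd.array(0.5 * np.sin(2 * np.pi * 440.0 * t))  # 440 Hz sine, amplitude 0.5
sw = SummaryWriter(logdir=_LOGDIR)
sw.add_audio(tag='sine_440hz', audio=tone, sample_rate=sample_rate)
sw.close()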
Example #5
def test_add_image():
    shape = list(rand_shape_nd(4))
    shape[1] = 3
    shape = tuple(shape)
    sw = SummaryWriter(logdir=_LOGDIR)
    sw.add_image(tag='test_add_image', image=mx.nd.random.normal(shape=shape), global_step=0)
    sw.close()
    check_event_file_and_remove_logdir()
Example #6
def check_add_histogram(data):
    sw = SummaryWriter(logdir=_LOGDIR)
    sw.add_histogram(tag='test_add_histogram',
                     values=data,
                     global_step=0,
                     bins=100)
    sw.close()
    check_event_file_and_remove_logdir()
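Since check_add_histogram is parameterized over its input, callers can exercise different value types; for example (assuming, as mxboard's own tests do, that both NDArray and NumPy inputs are accepted):

import mxnet as mx
import numpy as np

check_add_histogram(mx.nd.random.normal(shape=(1000,)))        # NDArray input
check_add_histogram(np.random.normal(size=(1000,)))            # NumPy input
check_add_histogram(mx.nd.random.uniform(shape=(10, 10, 10)))  # multi-dimensional values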
Example #7
    def test_evaluate(self):
        from model import hybrid_model
        from model import trainer
        from data_loader.data_utils import data_gen
        import numpy as np
        from mxboard import SummaryWriter
        import os
        import shutil

        ctx = mx.gpu(1)
        num_of_vertices = 897
        batch_size = 50

        PeMS_dataset = data_gen('datasets/PeMSD7_V_897.csv', 24)
        print('>> Loading dataset with Mean: {0:.2f}, STD: {1:.2f}'.format(
            PeMS_dataset.mean, PeMS_dataset.std))

        test = PeMS_dataset['test'].transpose((0, 3, 1, 2))
        test_x, test_y = test[:100, :, :12, :], test[:100, :, 12:, :]
        test_loader = gluon.data.DataLoader(gluon.data.ArrayDataset(
            nd.array(test_x), nd.array(test_y)),
                                            batch_size=batch_size,
                                            shuffle=False)
        print(test_x.shape, test_y.shape)

        cheb_polys = nd.random_uniform(shape=(num_of_vertices,
                                              num_of_vertices * 3))
        blocks = [[1, 32, 64], [64, 32, 128]]
        x = nd.random_uniform(shape=(batch_size, 1, 12, num_of_vertices),
                              ctx=ctx)

        net = hybrid_model.STGCN(12, 3, 3, blocks, 1.0, num_of_vertices,
                                 cheb_polys)
        net.initialize(ctx=ctx)
        net.hybridize()
        net(x)

        ground_truth = (
            np.concatenate([y.asnumpy() for x, y in test_loader], axis=0) *
            PeMS_dataset.std + PeMS_dataset.mean)[:100]

        if os.path.exists('test_logs'):
            shutil.rmtree('test_logs')
        sw = SummaryWriter('test_logs', flush_secs=5)

        trainer.evaluate(net, ctx, ground_truth, test_loader, 12,
                         PeMS_dataset.mean, PeMS_dataset.std, sw, 0)
        self.assertTrue(os.path.exists('test_logs'))
        sw.close()
        if os.path.exists('test_logs'):
            shutil.rmtree('test_logs')
Example #8
class Logger:
    """
    mxboard for mxnet
    """
    def __init__(self, config):
        self.config = config
        self.train_summary_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), "logs", "train")
        self.validate_summary_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), "logs", "val")
        if not os.path.exists(self.train_summary_dir):
            os.makedirs(self.train_summary_dir)
        if not os.path.exists(self.validate_summary_dir):
            os.makedirs(self.validate_summary_dir)
        self.train_summary_writer = SummaryWriter(self.train_summary_dir)
        self.validate_summary_writer = SummaryWriter(self.validate_summary_dir)

    # it can summarize scalars and images.
    def data_summarize(self, step, summarizer="train", summaries_dict=None):
        """
        :param step: the step of the summary
        :param summarizer: use the train summary writer or the validate one
        :param summaries_dict: the dict of the summaries values (tag,value)
        :return:
        """
        summary_writer = self.train_summary_writer if summarizer == "train" else self.validate_summary_writer
        if summaries_dict is not None:
            # summary_writer.add_scalars('./', summaries_dict, step)
            for tag, value in summaries_dict.items():
                summary_writer.add_scalar(tag=tag,
                                          value=value,
                                          global_step=step)
            summary_writer.flush()
            # summary = tf.Summary()
            # for tag, value in summaries_dict.items():
            #     summary.value.add(tag=tag, simple_value=value)
            # summary_writer.add_summary(summary, step)
            # summary_writer.flush()

    def graph_summary(self, net, summarizer="train"):
        summary_writer = self.train_summary_writer if summarizer == "train" else self.validate_summary_writer
        input_to_model = mxnet.ndarray.ones(
            shape=(1, self.config['num_channels'], self.config['img_height'],
                   self.config['img_width']),
            dtype='float32')
        summary_writer.add_graph(net, (input_to_model, ))

    def close(self):
        self.train_summary_writer.close()
        self.validate_summary_writer.close()
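A hypothetical usage sketch for this Logger, assuming config is a plain dict carrying the three image-shape keys that graph_summary reads:

config = {'num_channels': 3, 'img_height': 224, 'img_width': 224}
logger = Logger(config)
logger.data_summarize(step=0, summarizer='train',
                      summaries_dict={'loss': 0.93, 'accuracy': 0.41})
logger.graph_summary(net)  # net: any initialized gluon/mxnet model
logger.close()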
Example #9
def check_add_image(data):
    sw = SummaryWriter(logdir=_LOGDIR)
    sw.add_image(tag='test_add_image', image=data, global_step=0)
    sw.close()
    check_event_file_and_remove_logdir()
Example #10
def train_net(args, ctx, pretrained, pretrained_flow, epoch, prefix, begin_epoch, end_epoch, lr, lr_step):
    sw = SummaryWriter(logdir=config.output_path, flush_secs=5)
    logger, final_output_path = create_logger(config.output_path, args.cfg, config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # load symbol
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'), final_output_path)
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_train_symbol(config)

    sw.add_graph(sym)

    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load dataset and prepare imdb for training
    image_sets = [iset for iset in config.dataset.val_image_set.split('+')]
    roidbs = [load_gt_roidb(config.dataset.dataset, image_set, config.dataset.root_path, config.dataset.dataset_path,
                            flip=config.TRAIN.FLIP)
              for image_set in image_sets]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, config)
    # load training data
    train_data = AnchorLoader(feat_sym, roidb, config, batch_size=input_batch_size, shuffle=config.TRAIN.SHUFFLE,
                              ctx=ctx,
                              feat_stride=config.network.RPN_FEAT_STRIDE, anchor_scales=config.network.ANCHOR_SCALES,
                              anchor_ratios=config.network.ANCHOR_RATIOS, aspect_grouping=config.TRAIN.ASPECT_GROUPING,
                              normalize_target=config.network.NORMALIZE_RPN, bbox_mean=config.network.ANCHOR_MEANS,
                              bbox_std=config.network.ANCHOR_STDS)

    roidbs_eval = [
        load_gt_roidb(config.dataset.dataset, image_set, config.dataset.root_path, config.dataset.dataset_path,
                      flip=False)
        for image_set in image_sets]
    roidb_eval = merge_roidb(roidbs_eval)
    # need?
    roidb_eval = filter_roidb(roidb_eval, config)
    eval_data = AnchorLoader(feat_sym, roidb_eval, config, batch_size=input_batch_size, shuffle=config.TRAIN.SHUFFLE,
                             ctx=ctx,
                             feat_stride=config.network.RPN_FEAT_STRIDE, anchor_scales=config.network.ANCHOR_SCALES,
                             anchor_ratios=config.network.ANCHOR_RATIOS, aspect_grouping=config.TRAIN.ASPECT_GROUPING,
                             normalize_target=config.network.NORMALIZE_RPN, bbox_mean=config.network.ANCHOR_MEANS,
                             bbox_std=config.network.ANCHOR_STDS)

    # infer max shape
    max_data_shape = [('data', (
        config.TRAIN.BATCH_IMAGES, 3, max([v[0] for v in config.SCALES]), max([v[1] for v in config.SCALES]))),
                      ('data_ref', (config.TRAIN.BATCH_IMAGES, 3, max([v[0] for v in config.SCALES]),
                                    max([v[1] for v in config.SCALES]))),
                      ('eq_flag', (1,))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (config.TRAIN.BATCH_IMAGES, 100, 5)))
    print('providing maximum shape', max_data_shape, max_label_shape)

    data_shape_dict = dict(train_data.provide_data_single + train_data.provide_label_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    if config.TRAIN.RESUME:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        arg_params_flow, aux_params_flow = load_param(pretrained_flow, epoch, convert=True)
        arg_params.update(arg_params_flow)
        aux_params.update(aux_params_flow)
        sym_instance.init_weight(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params, data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    mod = MutableModule(sym, data_names=data_names, label_names=label_names,
                        logger=logger, context=ctx, max_data_shapes=[max_data_shape for _ in range(batch_size)],
                        max_label_shapes=[max_label_shape for _ in range(batch_size)],
                        fixed_param_prefix=fixed_param_prefix)

    if config.TRAIN.RESUME:
        mod._preload_opt_states = '%s-%04d.states' % (prefix, begin_epoch)

    # decide training params
    # metric
    rpn_eval_metric = metric.RPNAccMetric()
    rpn_cls_metric = metric.RPNLogLossMetric()
    rpn_bbox_metric = metric.RPNL1LossMetric()
    eval_metric = metric.RCNNAccMetric(config)
    cls_metric = metric.RCNNLogLossMetric(config)
    bbox_metric = metric.RCNNL1LossMetric(config)
    eval_metrics = mx.metric.CompositeEvalMetric()
    # rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric
    for child_metric in [rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric]:
        eval_metrics.add(child_metric)
    # callback
    batch_end_callback = [callback.Speedometer(train_data.batch_size, frequent=args.frequent, sw=sw),
                          callback.SummaryMetric(sw, frequent=args.frequent, prefix='train')]
    eval_end_callback = callback.SummaryValMetric(sw, prefix='val')
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), 2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), 2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    epoch_end_callback = [mx.callback.module_checkpoint(mod, prefix, period=1, save_optimizer_states=True),
                          callback.do_checkpoint(prefix, means, stds)]
    # decide learning rate
    base_lr = lr
    lr_factor = config.TRAIN.lr_factor
    lr_epoch = [float(epoch) for epoch in lr_step.split(',')]
    lr_epoch_diff = [epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch]
    lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [int(epoch * len(roidb) / batch_size) for epoch in lr_epoch_diff]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor, config.TRAIN.warmup, config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step, sw=sw)
    # optimizer
    optimizer_params = {'momentum': config.TRAIN.momentum,
                        'wd': config.TRAIN.wd,
                        'learning_rate': lr,
                        'lr_scheduler': lr_scheduler,
                        'rescale_grad': 1.0,
                        'clip_gradient': None}

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)
    if not isinstance(eval_data, PrefetchingIter):
        eval_data = PrefetchingIter(eval_data)

    # train
    mod.fit(train_data, eval_data=None, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback, eval_end_callback=eval_end_callback, kvstore=config.default.kvstore,
            optimizer='sgd', optimizer_params=optimizer_params, eval_num_batch=config.TEST.EVAL_NUM_BATCH,
            arg_params=arg_params, aux_params=aux_params, begin_epoch=begin_epoch, num_epoch=end_epoch)

    sw.close()
Example #11
def model_train(blocks, args, dataset, cheb_polys, ctx, logdir='./logdir'):
    '''
    Parameters
    ----------
    blocks: list[list], model structure, e.g. [[1, 32, 64], [64, 32, 128]]

    args: argparse.Namespace

    dataset: Dataset

    cheb_polys: mx.ndarray,
                shape is (num_of_vertices, order_of_cheb * num_of_vertices)

    ctx: mx.context.Context

    logdir: str, path of mxboard logdir

    '''

    num_of_vertices = args.num_of_vertices
    n_his, n_pred = args.n_his, args.n_pred
    order_of_cheb, Kt = args.order_of_cheb, args.kt
    batch_size, epochs = args.batch_size, args.epochs
    opt = args.opt
    keep_prob = args.keep_prob

    # data
    train = dataset['train'].transpose((0, 3, 1, 2))
    val = dataset['val'].transpose((0, 3, 1, 2))
    test = dataset['test'].transpose((0, 3, 1, 2))

    train_x, train_y = train[:, :, :n_his, :], train[:, :, n_his:, :]
    val_x, val_y = val[:, :, :n_his, :], val[:, :, n_his:, :]
    test_x, test_y = test[:, :, :n_his, :], test[:, :, n_his:, :]

    print(train_x.shape, train_y.shape, val_x.shape, val_y.shape, test_x.shape,
          test_y.shape)

    train_loader = gluon.data.DataLoader(gluon.data.ArrayDataset(
        nd.array(train_x), nd.array(train_y)),
                                         batch_size=batch_size,
                                         shuffle=False)
    val_loader = gluon.data.DataLoader(gluon.data.ArrayDataset(
        nd.array(val_x), nd.array(val_y)),
                                       batch_size=batch_size,
                                       shuffle=False)
    test_loader = gluon.data.DataLoader(gluon.data.ArrayDataset(
        nd.array(test_x), nd.array(test_y)),
                                        batch_size=batch_size,
                                        shuffle=False)

    ground_truth = (
        np.concatenate([y.asnumpy()
                        for x, y in test_loader], axis=0) * dataset.std +
        dataset.mean)

    # model
    model = hybrid_model.STGCN(n_his=n_his,
                               order_of_cheb=order_of_cheb,
                               Kt=Kt,
                               blocks=blocks,
                               keep_prob=keep_prob,
                               num_of_vertices=num_of_vertices,
                               cheb_polys=cheb_polys)
    model.initialize(ctx=ctx, init=mx.init.Xavier())
    model.hybridize()

    # loss function
    loss = gluon.loss.L2Loss()

    # trainer
    trainer = gluon.Trainer(model.collect_params(), args.opt)
    trainer.set_learning_rate(args.lr)

    if not os.path.exists('params'):
        os.mkdir('params')

    sw = SummaryWriter(logdir=logdir, flush_secs=5)
    train_step = 0
    val_step = 0

    for epoch in range(epochs):
        start_time = time.time()
        for x, y in train_loader:
            tmp = nd.concat(x, y, dim=2)
            for pred_idx in range(n_pred):
                end_idx = pred_idx + n_his
                x_ = tmp[:, :, pred_idx:end_idx, :]
                y_ = tmp[:, :, end_idx:end_idx + 1, :]
                with autograd.record():
                    l = loss(model(x_.as_in_context(ctx)),
                             y_.as_in_context(ctx))
                l.backward()
                sw.add_scalar(tag='training_loss',
                              value=l.mean().asscalar(),
                              global_step=train_step)
                trainer.step(x.shape[0])
                train_step += 1

        val_loss_list = []
        for x, y in val_loader:
            pred = predict_batch(model, ctx, x, n_pred)
            val_loss_list.append(loss(pred, y).mean().asscalar())
        sw.add_scalar(tag='val_loss',
                      value=sum(val_loss_list) / len(val_loss_list),
                      global_step=val_step)

        evaluate(model, ctx, ground_truth, test_loader, n_pred, dataset.mean,
                 dataset.std, sw, val_step)
        val_step += 1

        if (epoch + 1) % args.save == 0:
            model.save_parameters('params/{}.params'.format(epoch + 1))

    sw.close()
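predict_batch is assumed above; a minimal sketch consistent with how it is called here, rolling the model forward n_pred steps by sliding the n_his-long input window and feeding each one-step prediction back in:

def predict_batch(model, ctx, x, n_pred):
    # hypothetical helper: autoregressive multi-step prediction
    x = x.as_in_context(ctx)
    preds = []
    for _ in range(n_pred):
        step = model(x)                             # (batch, 1, 1, num_of_vertices)
        preds.append(step)
        x = nd.concat(x[:, :, 1:, :], step, dim=2)  # slide the time window forward
    # concatenate along the time axis and move back to CPU to match y
    return nd.concat(*preds, dim=2).as_in_context(mx.cpu())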
Example #12
def train_net(net, config, check_flag, logger, sig_state, sig_pgbar, sig_table):
    print(config)
    # config = Configs()
    # matplotlib.use('Agg')
    # import matplotlib.pyplot as plt
    sig_pgbar.emit(-1)
    mx.random.seed(1)
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    classes = 10
    num_epochs = config.train_cfg.epoch
    batch_size = config.train_cfg.batchsize
    optimizer = config.lr_cfg.optimizer
    lr = config.lr_cfg.lr
    num_gpus = config.train_cfg.gpu
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = config.data_cfg.worker

    warmup = config.lr_cfg.warmup
    if config.lr_cfg.decay == 'cosine':
        lr_sch = lr_scheduler.CosineScheduler((50000//batch_size)*num_epochs,
                                              base_lr=lr,
                                              warmup_steps=warmup *
                                              (50000//batch_size),
                                              final_lr=1e-5)
    else:
        lr_sch = lr_scheduler.FactorScheduler((50000//batch_size)*config.lr_cfg.factor_epoch,
                                              factor=config.lr_cfg.factor,
                                              base_lr=lr,
                                              warmup_steps=warmup*(50000//batch_size))

    model_name = config.net_cfg.name

    if config.data_cfg.mixup:
        model_name += '_mixup'
    if config.train_cfg.amp:
        model_name += '_amp'

    base_dir = './'+model_name
    if os.path.exists(base_dir):
        base_dir = base_dir + '-' + \
            time.strftime("%m-%d-%H.%M.%S", time.localtime())
    makedirs(base_dir)

    if config.save_cfg.tensorboard:
        logdir = base_dir+'/tb/'+model_name
        if os.path.exists(logdir):
            logdir = logdir + '-' + \
                time.strftime("%m-%d-%H.%M.%S", time.localtime())
        sw = SummaryWriter(logdir=logdir, flush_secs=5, verbose=False)
        cmd_file = open(base_dir+'/tb.bat', mode='w')
        cmd_file.write('tensorboard --logdir=./')
        cmd_file.close()

    save_period = 10
    save_dir = base_dir+'/'+'params'
    makedirs(save_dir)

    plot_name = base_dir+'/'+'plot'
    makedirs(plot_name)

    stat_name = base_dir+'/'+'stat.txt'

    csv_name = base_dir+'/'+'data.csv'
    if os.path.exists(csv_name):
        csv_name = base_dir+'/'+'data-' + \
            time.strftime("%m-%d-%H.%M.%S", time.localtime())+'.csv'
    csv_file = open(csv_name, mode='w', newline='')
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow(['Epoch', 'train_loss', 'train_acc',
                         'valid_loss', 'valid_acc', 'lr', 'time'])

    logging_handlers = [logging.StreamHandler(), logger]
    logging_handlers.append(logging.FileHandler(
        '%s/train_cifar10_%s.log' % (model_name, model_name)))

    logging.basicConfig(level=logging.INFO, handlers=logging_handlers)
    logging.info(config)

    if config.train_cfg.amp:
        amp.init()

    if config.save_cfg.profiler:
        profiler.set_config(profile_all=True,
                            aggregate_stats=True,
                            continuous_dump=True,
                            filename=base_dir+'/%s_profile.json' % model_name)
        is_profiler_run = False

    trans_list = []
    imgsize = config.data_cfg.size
    if config.data_cfg.crop:
        trans_list.append(gcv_transforms.RandomCrop(
            32, pad=config.data_cfg.crop_pad))
    if config.data_cfg.cutout:
        trans_list.append(CutOut(config.data_cfg.cutout_size))
    if config.data_cfg.flip:
        trans_list.append(transforms.RandomFlipLeftRight())
    if config.data_cfg.erase:
        trans_list.append(gcv_transforms.block.RandomErasing(s_max=0.25))
    trans_list.append(transforms.Resize(imgsize))
    trans_list.append(transforms.ToTensor())
    trans_list.append(transforms.Normalize([0.4914, 0.4822, 0.4465],
                                           [0.2023, 0.1994, 0.2010]))

    transform_train = transforms.Compose(trans_list)

    transform_test = transforms.Compose([
        transforms.Resize(imgsize),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    def label_transform(label, classes):
        ind = label.astype('int')
        res = nd.zeros((ind.shape[0], classes), ctx=label.context)
        res[nd.arange(ind.shape[0], ctx=label.context), ind] = 1
        return res

    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
        num_batch = len(val_data)
        test_loss = 0
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(
                batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(
                batch[1], ctx_list=ctx, batch_axis=0)
            outputs = [net(X) for X in data]
            loss = [loss_fn(yhat, y) for yhat, y in zip(outputs, label)]
            metric.update(label, outputs)
            test_loss += sum([l.sum().asscalar() for l in loss])
        test_loss /= batch_size * num_batch
        name, val_acc = metric.get()
        return name, val_acc, test_loss

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]

        if config.train_cfg.param_init:
            init_func = getattr(mx.init, config.train_cfg.init)
            net.initialize(init_func(), ctx=ctx, force_reinit=True)
        else:
            net.load_parameters(config.train_cfg.param_file, ctx=ctx)

        summary(net, stat_name, nd.uniform(
            shape=(1, 3, imgsize, imgsize), ctx=ctx[0]))
        # net = nn.HybridBlock()
        net.hybridize()

        root = config.dir_cfg.dataset
        train_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(
                root=root, train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

        val_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(
                root=root, train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        trainer_arg = {'learning_rate': config.lr_cfg.lr,
                       'wd': config.lr_cfg.wd, 'lr_scheduler': lr_sch}
        extra_arg = eval(config.lr_cfg.extra_arg)
        trainer_arg.update(extra_arg)
        trainer = gluon.Trainer(net.collect_params(), optimizer, trainer_arg)
        if config.train_cfg.amp:
            amp.init_trainer(trainer)
        metric = mx.metric.Accuracy()
        train_metric = mx.metric.RMSE()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss(
            sparse_label=False if config.data_cfg.mixup else True)
        train_history = TrainingHistory(['training-error', 'validation-error'])
        # acc_history = TrainingHistory(['training-acc', 'validation-acc'])
        loss_history = TrainingHistory(['training-loss', 'validation-loss'])

        iteration = 0

        best_val_score = 0

        # print('start training')
        sig_state.emit(1)
        sig_pgbar.emit(0)
        # signal.emit('Training')
        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1
            for i, batch in enumerate(train_data):
                if epoch == 0 and iteration == 1 and config.save_cfg.profiler:
                    profiler.set_state('run')
                    is_profiler_run = True
                if epoch == 0 and iteration == 1 and config.save_cfg.tensorboard:
                    sw.add_graph(net)
                lam = np.random.beta(alpha, alpha)
                if epoch >= epochs - 20 or not config.data_cfg.mixup:
                    lam = 1

                data_1 = gluon.utils.split_and_load(
                    batch[0], ctx_list=ctx, batch_axis=0)
                label_1 = gluon.utils.split_and_load(
                    batch[1], ctx_list=ctx, batch_axis=0)

                if not config.data_cfg.mixup:
                    data = data_1
                    label = label_1
                else:
                    data = [lam*X + (1-lam)*X[::-1] for X in data_1]
                    label = []
                    for Y in label_1:
                        y1 = label_transform(Y, classes)
                        y2 = label_transform(Y[::-1], classes)
                        label.append(lam*y1 + (1-lam)*y2)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                if config.train_cfg.amp:
                    with ag.record():
                        with amp.scale_loss(loss, trainer) as scaled_loss:
                            ag.backward(scaled_loss)
                            # scaled_loss.backward()
                else:
                    for l in loss:
                        l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                output_softmax = [nd.SoftmaxActivation(out) for out in output]
                train_metric.update(label, output_softmax)
                metric.update(label_1, output_softmax)
                name, acc = train_metric.get()
                if config.save_cfg.tensorboard:
                    sw.add_scalar(tag='lr', value=trainer.learning_rate,
                                  global_step=iteration)
                if epoch == 0 and iteration == 1 and config.save_cfg.profiler:
                    nd.waitall()
                    profiler.set_state('stop')
                    profiler.dump()
                iteration += 1
                sig_pgbar.emit(iteration)
                if check_flag()[0]:
                    sig_state.emit(2)
                while check_flag()[0] or check_flag()[1]:
                    if check_flag()[1]:
                        print('stop')
                        return
                    else:
                        time.sleep(5)
                        print('pausing')

            epoch_time = time.time() - tic
            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            _, train_acc = metric.get()
            name, val_acc, _ = test(ctx, val_data)
            # if config.data_cfg.mixup:
            #     train_history.update([acc, 1-val_acc])
            #     plt.cla()
            #     train_history.plot(save_path='%s/%s_history.png' %
            #                        (plot_name, model_name))
            # else:
            train_history.update([1-train_acc, 1-val_acc])
            plt.cla()
            train_history.plot(save_path='%s/%s_history.png' %
                               (plot_name, model_name))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('%s/%.4f-cifar-%s-%d-best.params' %
                                    (save_dir, best_val_score, model_name, epoch))

            current_lr = trainer.learning_rate
            name, val_acc, val_loss = test(ctx, val_data)

            logging.info('[Epoch %d] loss=%f train_acc=%f train_RMSE=%f\n     val_acc=%f val_loss=%f lr=%f time: %f' %
                         (epoch, train_loss, train_acc, acc, val_acc, val_loss, current_lr, epoch_time))
            loss_history.update([train_loss, val_loss])
            plt.cla()
            loss_history.plot(save_path='%s/%s_loss.png' %
                              (plot_name, model_name), y_lim=(0, 2), legend_loc='best')
            if config.save_cfg.tensorboard:
                sw._add_scalars(tag='Acc',
                                scalar_dict={'train_acc': train_acc, 'test_acc': val_acc}, global_step=epoch)
                sw._add_scalars(tag='Loss',
                                scalar_dict={'train_loss': train_loss, 'test_loss': val_loss}, global_step=epoch)

            sig_table.emit([epoch, train_loss, train_acc,
                            val_loss, val_acc, current_lr, epoch_time])
            csv_writer.writerow([epoch, train_loss, train_acc,
                                 val_loss, val_acc, current_lr, epoch_time])
            csv_file.flush()

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar10-%s-%d.params' %
                                    (save_dir, model_name, epoch))
        if save_period and save_dir:
            net.save_parameters('%s/cifar10-%s-%d.params' %
                                (save_dir, model_name, epochs-1))

    train(num_epochs, context)
    if config.save_cfg.tensorboard:
        sw.close()

    for ctx in context:
        ctx.empty_cache()

    csv_file.close()
    logging.shutdown()
    reload(logging)
    sig_state.emit(0)
Example #13
def main():
    opt = parse_args(parser)

    assert not (os.path.isdir(opt.save_dir)), "already done this experiment..."
    Path(opt.save_dir).mkdir(parents=True)

    filehandler = logging.FileHandler(
        os.path.join(opt.save_dir, opt.logging_file))
    streamhandler = logging.StreamHandler()
    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)
    logger.info(opt)

    sw = SummaryWriter(logdir=opt.save_dir, flush_secs=5, verbose=False)

    if opt.use_amp:
        amp.init()

    batch_size = opt.batch_size
    classes = opt.num_classes

    # num_gpus = opt.num_gpus
    # batch_size *= max(1, num_gpus)
    # logger.info('Total batch size is set to %d on %d GPUs' % (batch_size, num_gpus))
    # context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    # num_workers = opt.num_workers

    num_gpus = 1
    context = [mx.gpu(i) for i in range(num_gpus)]
    per_device_batch_size = 5
    num_workers = 12
    batch_size = per_device_batch_size * num_gpus

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]

    if opt.slowfast:
        optimizer = 'nag'
    else:
        optimizer = 'sgd'

    if opt.clip_grad > 0:
        optimizer_params = {
            'learning_rate': opt.lr,
            'wd': opt.wd,
            'momentum': opt.momentum,
            'clip_gradient': opt.clip_grad
        }
    else:
        # optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum}
        optimizer_params = {'wd': opt.wd, 'momentum': opt.momentum}

    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    model_name = opt.model
    if opt.use_pretrained and len(opt.hashtag) > 0:
        opt.use_pretrained = opt.hashtag
    net = get_model(name=model_name,
                    nclass=classes,
                    pretrained=opt.use_pretrained,
                    use_tsn=opt.use_tsn,
                    num_segments=opt.num_segments,
                    partial_bn=opt.partial_bn,
                    bn_frozen=opt.freeze_bn)
    # net.cast(opt.dtype)
    net.collect_params().reset_ctx(context)
    logger.info(net)

    resume_params = find_model_params(opt)
    if resume_params != '':
        net.load_parameters(resume_params, ctx=context)
        print('Continue training from model %s.' % (resume_params))

    train_data, val_data, batch_fn = get_data_loader(opt, batch_size,
                                                     num_workers, logger)

    iterations_per_epoch = len(train_data) // opt.accumulate
    lr_scheduler = CyclicalSchedule(CosineAnnealingSchedule,
                                    min_lr=0,
                                    max_lr=opt.lr,
                                    cycle_length=opt.T_0 *
                                    iterations_per_epoch,
                                    cycle_length_decay=opt.T_mult,
                                    cycle_magnitude_decay=1)
    optimizer_params['lr_scheduler'] = lr_scheduler

    optimizer = mx.optimizer.SGD(**optimizer_params)
    train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    def test(ctx, val_data, kvstore="None"):
        acc_top1.reset()
        acc_top5.reset()
        #get weights
        weights = get_weights(opt).reshape(1, opt.num_classes)
        weights = mx.nd.array(weights, ctx=mx.gpu(0))

        L = gluon.loss.SoftmaxCrossEntropyLoss()

        num_test_iter = len(val_data)
        val_loss_epoch = 0
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            outputs = []
            for _, X in enumerate(data):
                X = X.reshape((-1, ) + X.shape[2:])
                pred = net(X.astype(opt.dtype, copy=False))
                outputs.append(pred)

            if (opt.balanced):
                loss = [
                    L(yhat, y.astype(opt.dtype, copy=False), weights)
                    for yhat, y in zip(outputs, label)
                ]
            else:
                loss = [
                    L(yhat, y.astype(opt.dtype, copy=False))
                    for yhat, y in zip(outputs, label)
                ]

            # loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]

            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

            val_loss_epoch += sum([l.mean().asscalar()
                                   for l in loss]) / len(loss)

            if opt.log_interval and not (i + 1) % opt.log_interval:
                _, top1 = acc_top1.get()
                _, top5 = acc_top5.get()
                logger.info('Batch [%04d]/[%04d]: acc-top1=%f acc-top5=%f' %
                            (i, num_test_iter, top1 * 100, top5 * 100))

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        val_loss = val_loss_epoch / num_test_iter

        return (top1, top5, val_loss)

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        if opt.partial_bn:
            train_patterns = "None"
            if 'inceptionv3' in opt.model:
                train_patterns = '.*weight|.*bias|inception30_batchnorm0_gamma|inception30_batchnorm0_beta|inception30_batchnorm0_running_mean|inception30_batchnorm0_running_var'
            elif 'inceptionv1' in opt.model:
                train_patterns = '.*weight|.*bias|googlenet0_batchnorm0_gamma|googlenet0_batchnorm0_beta|googlenet0_batchnorm0_running_mean|googlenet0_batchnorm0_running_var'
            else:
                logger.info(
                    'Current model does not support partial batch normalization.'
                )

            # trainer = gluon.Trainer(net.collect_params(train_patterns), optimizer, optimizer_params, update_on_kvstore=False)
            trainer = gluon.Trainer(net.collect_params(train_patterns),
                                    optimizer,
                                    update_on_kvstore=False)

        elif opt.freeze_bn:
            train_patterns = '.*weight|.*bias'
            # trainer = gluon.Trainer(net.collect_params(train_patterns), optimizer, optimizer_params, update_on_kvstore=False)
            trainer = gluon.Trainer(net.collect_params(train_patterns),
                                    optimizer,
                                    update_on_kvstore=False)

        else:
            # trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params, update_on_kvstore=False)
            trainer = gluon.Trainer(net.collect_params(),
                                    optimizer,
                                    update_on_kvstore=False)

        if opt.accumulate > 1:
            params = [
                p for p in net.collect_params().values()
                if p.grad_req != 'null'
            ]
            for p in params:
                p.grad_req = 'add'

        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        if opt.use_amp:
            amp.init_trainer(trainer)

        L = gluon.loss.SoftmaxCrossEntropyLoss()

        best_val_score = 0
        lr_decay_count = 0
        #compute weights
        weights = get_weights(opt).reshape(1, opt.num_classes)
        weights = mx.nd.array(weights, ctx=mx.gpu(0))

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            train_metric.reset()
            btic = time.time()
            num_train_iter = len(train_data)
            train_loss_epoch = 0
            train_loss_iter = 0

            for i, batch in tqdm(enumerate(train_data)):
                data, label = batch_fn(batch, ctx)

                with ag.record():
                    outputs = []
                    for _, X in enumerate(data):
                        X = X.reshape((-1, ) + X.shape[2:])
                        # pred = net(X.astype(opt.dtype, copy=False))
                        pred = net(X)
                        outputs.append(pred)
                    if (opt.balanced):
                        loss = [
                            L(yhat, y.astype(opt.dtype, copy=False), weights)
                            for yhat, y in zip(outputs, label)
                        ]

                    else:
                        loss = [
                            L(yhat, y.astype(opt.dtype, copy=False))
                            for yhat, y in zip(outputs, label)
                        ]

                    if opt.use_amp:
                        with amp.scale_loss(loss, trainer) as scaled_loss:
                            ag.backward(scaled_loss)
                    else:
                        ag.backward(loss)

                if opt.accumulate > 1:
                    if (i + 1) % opt.accumulate == 0:
                        trainer.step(batch_size * opt.accumulate)
                        net.collect_params().zero_grad()
                else:
                    trainer.step(batch_size)

                train_metric.update(label, outputs)
                train_loss_iter = sum([l.mean().asscalar()
                                       for l in loss]) / len(loss)
                train_loss_epoch += train_loss_iter

                train_metric_name, train_metric_score = train_metric.get()
                sw.add_scalar(tag='train_acc_top1_iter',
                              value=train_metric_score * 100,
                              global_step=epoch * num_train_iter + i)
                sw.add_scalar(tag='train_loss_iter',
                              value=train_loss_iter,
                              global_step=epoch * num_train_iter + i)
                sw.add_scalar(tag='learning_rate_iter',
                              value=trainer.learning_rate,
                              global_step=epoch * num_train_iter + i)

                if opt.log_interval and not (i + 1) % opt.log_interval:
                    logger.info(
                        'Epoch[%03d] Batch [%04d]/[%04d]\tSpeed: %f samples/sec\t %s=%f\t loss=%f\t lr=%f'
                        % (epoch, i, num_train_iter,
                           batch_size * opt.log_interval /
                           (time.time() - btic), train_metric_name,
                           train_metric_score * 100, train_loss_epoch /
                           (i + 1), trainer.learning_rate))
                    btic = time.time()

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * i / (time.time() - tic))
            mx.ndarray.waitall()

            logger.info('[Epoch %03d] training: %s=%f\t loss=%f' %
                        (epoch, train_metric_name, train_metric_score * 100,
                         train_loss_epoch / num_train_iter))
            logger.info('[Epoch %03d] speed: %d samples/sec\ttime cost: %f' %
                        (epoch, throughput, time.time() - tic))
            sw.add_scalar(tag='train_loss_epoch',
                          value=train_loss_epoch / num_train_iter,
                          global_step=epoch)

            if not opt.train_only:
                acc_top1_val, acc_top5_val, loss_val = test(ctx, val_data)

                logger.info(
                    '[Epoch %03d] validation: acc-top1=%f acc-top5=%f loss=%f'
                    %
                    (epoch, acc_top1_val * 100, acc_top5_val * 100, loss_val))
                sw.add_scalar(tag='val_loss_epoch',
                              value=loss_val,
                              global_step=epoch)
                sw.add_scalar(tag='val_acc_top1_epoch',
                              value=acc_top1_val * 100,
                              global_step=epoch)

                if acc_top1_val > best_val_score:
                    best_val_score = acc_top1_val
                    net.save_parameters('%s/%.4f-%s-%s-%03d-best.params' %
                                        (opt.save_dir, best_val_score,
                                         opt.dataset, model_name, epoch))
                    trainer.save_states('%s/%.4f-%s-%s-%03d-best.states' %
                                        (opt.save_dir, best_val_score,
                                         opt.dataset, model_name, epoch))
                # else:
                #     if opt.save_frequency and opt.save_dir and (epoch + 1) % opt.save_frequency == 0:
                #         net.save_parameters('%s/%s-%s-%03d.params'%(opt.save_dir, opt.dataset, model_name, epoch))
                #         trainer.save_states('%s/%s-%s-%03d.states'%(opt.save_dir, opt.dataset, model_name, epoch))

        # # save the last model
        # net.save_parameters('%s/%s-%s-%03d.params'%(opt.save_dir, opt.dataset, model_name, opt.num_epochs-1))
        # trainer.save_states('%s/%s-%s-%03d.states'%(opt.save_dir, opt.dataset, model_name, opt.num_epochs-1))
        def return_float(el):
            return float(el)

        try:
            #remove "trash" files
            performances = [
                get_file_stem(file).split("-")[0]
                for file in os.listdir(opt.save_dir) if "params" in file
            ]

            best_performance = sorted(performances,
                                      key=return_float,
                                      reverse=True)[0]

            params_trash = [
                os.path.join(opt.save_dir, file)
                for file in os.listdir(opt.save_dir)
                if (("params" in file) and not (best_performance in file))
            ]
            states_trash = [
                os.path.join(opt.save_dir, file)
                for file in os.listdir(opt.save_dir)
                if (("states" in file) and not (best_performance in file))
            ]
            trash_files = params_trash + states_trash

            for file in trash_files:
                os.remove(file)
        except Exception:
            print("Something went wrong while cleaning up checkpoint files.")

    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)

    train(context)
    sw.close()
Example #14
        for name, param in net.collect_params().items():
            try:
                sw.add_histogram(tag=name + "_grad", values=param.grad(), global_step=global_step, bins=1000)
            except:
                print(name)
                print(param.grad())

        # compute validation loss
        compute_val_loss(net, val_loader, loss_function, sw, epoch)

        # evaluate the model on testing set
        evaluate(net, test_loader, true_value, num_of_vertices, sw, epoch)

        params_filename = os.path.join(params_path, '%s_epoch_%s.params'%(model_name, epoch))
        net.save_parameters(params_filename)
        print('save parameters to file: %s'%(params_filename))
    
    # close SummaryWriter
    sw.close()

    if 'prediction_filename' in training_config:
        prediction_path = training_config['prediction_filename']

        prediction = predict(net, test_loader)

        np.savez_compressed(
            os.path.normpath(prediction_path),
            prediction=prediction,
            ground_truth=all_data['test']['target']
        )
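The loop above logs gradient distributions; the same pattern works for the parameter values themselves (same assumptions about net, sw, and global_step):

for name, param in net.collect_params().items():
    sw.add_histogram(tag=name + "_weight", values=param.data(),
                     global_step=global_step, bins=1000)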
Example #15
    return stat


monitor = mx.mon.Monitor(100,
                         monitor_fc1_gradient,
                         pattern='fc1_backward_weight')


def monitor_fc1_weight(param):
    if param.nbatch % 100 == 0:
        arg_params, aux_params = param.locals['self'].get_params()
        # add_scalar expects a scalar value; log the full weight tensor
        # as a histogram instead
        summary_writer.add_histogram(
            tag='fc1-weight',
            values=arg_params['fc1_weight'].asnumpy().flatten(),
            global_step=param.nbatch,
            bins=1000)


batch_end_callbacks.append(monitor_fc1_weight)

model.fit(train_data=train_iter,
          begin_epoch=0,
          num_epoch=20,
          eval_data=test_iter,
          eval_metric='accuracy',
          optimizer='sgd',
          optimizer_params=optimizer_params,
          initializer=mx.init.Uniform(),
          batch_end_callback=batch_end_callbacks,
          eval_end_callback=eval_end_callbacks,
          monitor=monitor)
summary_writer.close()
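monitor_fc1_gradient is only hinted at by the trailing `return stat` above; a minimal sketch, assuming the stat function logs the matched gradient tensor as a histogram and returns its L2 norm for mx.mon.Monitor to print:

def monitor_fc1_gradient(grad):
    # hypothetical stat function for mx.mon.Monitor: record the gradient
    # distribution, then return a scalar statistic for console output
    summary_writer.add_histogram(tag='fc1-weight-grad', values=grad, bins=1000)
    stat = mx.nd.norm(grad)
    return stat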
Example #16
def main():
    opt = parse_args()

    makedirs(opt.save_dir)

    filehandler = logging.FileHandler(os.path.join(opt.save_dir, opt.logging_file))
    streamhandler = logging.StreamHandler()
    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)
    logger.info(opt)

    sw = SummaryWriter(logdir=opt.save_dir, flush_secs=5)

    batch_size = opt.batch_size
    classes = opt.num_classes

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    logger.info('Total batch size is set to %d on %d GPUs' % (batch_size, num_gpus))
    context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]

    optimizer = 'sgd'
    if opt.clip_grad > 0:
        optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum, 'clip_gradient': opt.clip_grad}
    else:
        optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum}

    model_name = opt.model
    net = get_model(name=model_name, nclass=classes, pretrained=opt.use_pretrained,
                    tsn=opt.use_tsn, num_segments=opt.num_segments, partial_bn=opt.partial_bn)
    net.cast(opt.dtype)
    net.collect_params().reset_ctx(context)
    logger.info(net)

    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx=context)

    train_data, val_data, batch_fn = get_data_loader(opt, batch_size, num_workers, logger)

    train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    def test(ctx, val_data):
        acc_top1.reset()
        acc_top5.reset()
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        return (top1, top5)

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        if opt.partial_bn:
            train_patterns = None
            if 'inceptionv3' in opt.model:
                train_patterns = '.*weight|.*bias|inception30_batchnorm0_gamma|inception30_batchnorm0_beta|inception30_batchnorm0_running_mean|inception30_batchnorm0_running_var'
            else:
                logger.info('Current model does not support partial batch normalization.')
            trainer = gluon.Trainer(net.collect_params(train_patterns), optimizer, optimizer_params)
        else:
            trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)

        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        L = gluon.loss.SoftmaxCrossEntropyLoss()

        best_val_score = 0
        lr_decay_count = 0

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            train_metric.reset()
            btic = time.time()

            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                lr_decay_count += 1

            for i, batch in enumerate(train_data):
                data, label = batch_fn(batch, ctx)

                with ag.record():
                    outputs = [net(X.astype(opt.dtype, copy=False)) for X in data]
                    loss = [L(yhat, y.astype(opt.dtype, copy=False)) for yhat, y in zip(outputs, label)]

                for l in loss:
                    l.backward()

                trainer.step(batch_size)
                train_metric.update(label, outputs)

                if opt.log_interval and not (i+1) % opt.log_interval:
                    train_metric_name, train_metric_score = train_metric.get()
                    logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f' % (
                                epoch, i, batch_size*opt.log_interval/(time.time()-btic),
                                train_metric_name, train_metric_score*100, trainer.learning_rate))
                    btic = time.time()

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * i /(time.time() - tic))

            acc_top1_val, acc_top5_val = test(ctx, val_data)

            logger.info('[Epoch %d] training: %s=%f'%(epoch, train_metric_name, train_metric_score*100))
            logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f'%(epoch, throughput, time.time()-tic))
            logger.info('[Epoch %d] validation: acc-top1=%f acc-top5=%f'%(epoch, acc_top1_val*100, acc_top5_val*100))

            sw.add_scalar(tag='train_acc', value=train_metric_score*100, global_step=epoch)
            sw.add_scalar(tag='valid_acc', value=acc_top1_val*100, global_step=epoch)

            if acc_top1_val > best_val_score:
                best_val_score = acc_top1_val
                if opt.use_tsn:
                    net.basenet.save_parameters('%s/%.4f-ucf101-%s-%03d-best.params'%(opt.save_dir, best_val_score, model_name, epoch))
                else:
                    net.save_parameters('%s/%.4f-ucf101-%s-%03d-best.params'%(opt.save_dir, best_val_score, model_name, epoch))
                trainer.save_states('%s/%.4f-ucf101-%s-%03d-best.states'%(opt.save_dir, best_val_score, model_name, epoch))

            if opt.save_frequency and opt.save_dir and (epoch + 1) % opt.save_frequency == 0:
                if opt.use_tsn:
                    net.basenet.save_parameters('%s/ucf101-%s-%03d.params'%(opt.save_dir, model_name, epoch))
                else:
                    net.save_parameters('%s/ucf101-%s-%03d.params'%(opt.save_dir, model_name, epoch))
                trainer.save_states('%s/ucf101-%s-%03d.states'%(opt.save_dir, model_name, epoch))

        # save the last model
        if opt.use_tsn:
            net.basenet.save_parameters('%s/ucf101-%s-%03d.params'%(opt.save_dir, model_name, opt.num_epochs-1))
        else:
            net.save_parameters('%s/ucf101-%s-%03d.params'%(opt.save_dir, model_name, opt.num_epochs-1))
        trainer.save_states('%s/ucf101-%s-%03d.states'%(opt.save_dir, model_name, opt.num_epochs-1))

    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)

    train(context)
    sw.close()
Example #17
class BaseTrainer:
    def __init__(self, config, model, criterion, ctx):
        config['trainer']['output_dir'] = os.path.join(
            str(pathlib.Path(os.path.abspath(__name__)).parent),
            config['trainer']['output_dir'])
        config['name'] = config['name'] + '_' + model.model_name
        self.save_dir = os.path.join(config['trainer']['output_dir'],
                                     config['name'])
        self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')

        if config['trainer']['resume_checkpoint'] == '' and config['trainer'][
                'finetune_checkpoint'] == '':
            shutil.rmtree(self.save_dir, ignore_errors=True)
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)
        # save this experiment's alphabet alongside the saved model
        np.save(os.path.join(self.save_dir, 'alphabet.npy'),
                config['data_loader']['args']['dataset']['alphabet'])
        self.global_step = 0
        self.start_epoch = 1
        self.config = config

        self.model = model
        self.criterion = criterion
        # logger and tensorboard
        self.tensorboard_enable = self.config['trainer']['tensorboard']
        self.epochs = self.config['trainer']['epochs']
        self.display_interval = self.config['trainer']['display_interval']
        if self.tensorboard_enable:
            from mxboard import SummaryWriter
            self.writer = SummaryWriter(self.save_dir, verbose=False)

        self.logger = setup_logger(os.path.join(self.save_dir, 'train_log'))
        self.logger.info(pformat(self.config))
        self.logger.info(self.model)
        # device set
        self.ctx = ctx
        mx.random.seed(2)  # set the random seed

        self.logger.info('train with mxnet: {} and device: {}'.format(
            mx.__version__, self.ctx))
        self.metrics = {
            'val_acc': 0,
            'train_loss': float('inf'),
            'best_model': ''
        }

        schedule = self._initialize('lr_scheduler', mx.lr_scheduler)
        optimizer = self._initialize('optimizer',
                                     mx.optimizer,
                                     lr_scheduler=schedule)
        self.trainer = gluon.Trainer(self.model.collect_params(),
                                     optimizer=optimizer)

        if self.config['trainer']['resume_checkpoint'] != '':
            self._load_checkpoint(self.config['trainer']['resume_checkpoint'],
                                  resume=True)
        elif self.config['trainer']['finetune_checkpoint'] != '':
            self._load_checkpoint(
                self.config['trainer']['finetune_checkpoint'], resume=False)

        if self.tensorboard_enable:
            try:
                # add graph
                from mxnet.gluon import utils as gutils
                dummy_input = gutils.split_and_load(
                    nd.zeros((
                        1, self.config['data_loader']['args']['dataset']
                        ['img_channel'],
                        self.config['data_loader']['args']['dataset']['img_h'],
                        self.config['data_loader']['args']['dataset']['img_w']
                    )), ctx)
                self.model(dummy_input[0])
                self.writer.add_graph(model)
            except Exception:
                self.logger.error(traceback.format_exc())
                self.logger.warning('add graph to tensorboard failed')

    def train(self):
        """
        Full training logic
        """
        try:
            for epoch in range(self.start_epoch, self.epochs + 1):
                self.epoch_result = self._train_epoch(epoch)
                self._on_epoch_finish()
        except Exception:
            self.logger.error(traceback.format_exc())
        if self.tensorboard_enable:
            self.writer.close()
        self._on_train_finish()

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def _eval(self):
        """
        eval logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def _on_epoch_finish(self):
        raise NotImplementedError

    def _on_train_finish(self):
        raise NotImplementedError

    def _save_checkpoint(self, epoch, file_name, save_best=False):
        """
        Save the checkpoint: stores the model weights, trainer state, and other metadata
        :param epoch: current epoch
        :param file_name: file name
        :param save_best: whether this is the best model so far
        :return:
        """

        # save the weights
        params_filename = os.path.join(self.checkpoint_dir, file_name)
        self.model.save_parameters(params_filename)
        # save the trainer state
        trainer_filename = params_filename.replace('.params', '.train_states')
        self.trainer.save_states(trainer_filename)
        # other metadata
        state = {
            'epoch': epoch,
            'global_step': self.global_step,
            'config': self.config,
            'metrics': self.metrics
        }
        other_filename = params_filename.replace('.params', '.info')
        pickle.dump(state, open(other_filename, 'wb'))
        if save_best:
            shutil.copy(params_filename,
                        os.path.join(self.checkpoint_dir, 'model_best.params'))
            shutil.copy(
                trainer_filename,
                os.path.join(self.checkpoint_dir, 'model_best.train_states'))
            shutil.copy(other_filename,
                        os.path.join(self.checkpoint_dir, 'model_best.info'))
            self.logger.info("Saving current best: {}".format(
                os.path.join(self.checkpoint_dir, 'model_best.params')))
        else:
            self.logger.info("Saving checkpoint: {}".format(params_filename))

    def _load_checkpoint(self, checkpoint_path, resume):
        """
        Load from a checkpoint: restores the model weights, trainer state, and other metadata
        :param checkpoint_path: path to the checkpoint
        :param resume: whether to also restore the trainer state and training progress
        :return:
        """
        self.logger.info("Loading checkpoint: {} ...".format(checkpoint_path))

        # load the model parameters
        self.model.load_parameters(checkpoint_path,
                                   ctx=self.ctx,
                                   ignore_extra=True,
                                   allow_missing=True)
        if resume:
            # load the trainer state
            trainer_filename = checkpoint_path.replace('.params',
                                                       '.train_states')
            if os.path.exists(trainer_filename):
                self.trainer.load_states(trainer_filename)

            # load other metadata
            other_filename = checkpoint_path.replace('.params', '.info')
            checkpoint = pickle.load(open(other_filename, 'rb'))
            self.start_epoch = checkpoint['epoch'] + 1
            self.global_step = checkpoint['global_step']
            self.metrics = checkpoint['metrics']
            self.logger.info("resume from checkpoint {} (epoch {})".format(
                checkpoint_path, self.start_epoch))
        else:
            self.logger.info(
                "finetune from checkpoint {}".format(checkpoint_path))

    def _initialize(self, name, module, *args, **kwargs):
        module_name = self.config[name]['type']
        module_args = self.config[name]['args']
        assert all([
            k not in module_args for k in kwargs
        ]), 'Overwriting kwargs given in config file is not allowed'
        module_args.update(kwargs)
        return getattr(module, module_name)(*args, **module_args)
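`_initialize` looks up a class by name in the given module and instantiates it with the kwargs stored in the config, so the scheduler and optimizer are fully config-driven. A hypothetical config fragment that would drive the two `_initialize` calls in `__init__` (the `type`/`args` keys are what the method reads; the concrete values here are assumptions):

# hypothetical config fragment; each 'type' must name a real class in the
# module passed to _initialize (mx.lr_scheduler and mx.optimizer here)
config = {
    'lr_scheduler': {
        'type': 'FactorScheduler',
        'args': {'step': 1000, 'factor': 0.9}
    },
    'optimizer': {
        'type': 'Adam',
        'args': {'learning_rate': 0.001}
    }
}
# _initialize('optimizer', mx.optimizer, lr_scheduler=schedule) then resolves
# to mx.optimizer.Adam(learning_rate=0.001, lr_scheduler=schedule)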
Ejemplo n.º 18
0
class TrainerAgentMXNET:  # Probably needs refactoring
    """Main training loop"""
    def __init__(self,
                 model,
                 symbol,
                 val_iter,
                 train_config: TrainConfig,
                 train_objects: TrainObjects,
                 use_rtpt: bool,
                 augment=False):
        """
        Class for training the neural network.
        :param model: The model loaded with the MXNet Module functionalities.
        :param symbol: The architecture of the neural network.
        :param val_iter: Iterable object over the validation data.
        :param train_config: An instance of the TrainConfig data class.
        :param train_objects: An instance of the TrainObjects data class.
        :param use_rtpt: If True, an RTPT object will be created and modified within this class.
        :param augment: If True, each data part is loaded twice and may be augmented.
        """
        # Too many instance attributes (29/7) - Too many arguments (24/5) - Too many local variables (25/15)
        # Too few public methods (1/2)
        self.tc = train_config
        self.to = train_objects
        if self.to.metrics is None:
            self.to.metrics = {}
        self._model = model
        self._symbol = symbol
        self._val_iter = val_iter
        self.x_train = self.yv_train = self.yp_train = None
        self._ctx = get_context(train_config.context, train_config.device_id)
        self._augment = augment

        # define a summary writer that logs data and flushes to the file every 5 seconds
        if self.tc.log_metrics_to_tensorboard:
            self.sum_writer = SummaryWriter(logdir=self.tc.export_dir + "logs",
                                            flush_secs=5,
                                            verbose=False)
        # Define the optimizer
        if self.tc.optimizer_name == "adam":
            self.optimizer = mx.optimizer.Adam(
                learning_rate=0.001,
                beta1=0.9,
                beta2=0.999,
                epsilon=1e-8,
                lazy_update=True,
                rescale_grad=(1.0 / self.tc.batch_size))
        elif self.tc.optimizer_name == "nag":
            self.optimizer = mx.optimizer.NAG(
                momentum=self.to.momentum_schedule(0),
                wd=self.tc.wd,
                rescale_grad=(1.0 / self.tc.batch_size))
        else:
            raise Exception("%s is currently not supported as an optimizer." %
                            self.tc.optimizer_name)
        self.ordering = list(
            range(self.tc.nb_parts)
        )  # define a list which describes the order of the processed batches
        # if we augment the data set each part is loaded twice
        if self._augment:
            self.ordering += self.ordering
        # decides if the policy indices shall be selected directly from spatial feature maps without dense layer
        self.batch_end_callbacks = [self.batch_callback]

        # few variables which are internally used
        self.val_loss_best = self.val_p_acc_best = self.k_steps_best = \
            self.old_label = self.value_out = self.t_s = None
        self.patience_cnt = self.batch_proc_tmp = None
        # calculate how many log states will be processed
        self.k_steps_end = round(self.tc.total_it / self.tc.batch_steps)
        if self.k_steps_end == 0:
            self.k_steps_end = 1
        self.k_steps = self.cur_it = self.nb_spikes = self.old_val_loss = self.continue_training = self.t_s_steps = None
        self._train_iter = self.graph_exported = self.val_metric_values = self.val_loss = self.val_p_acc = None
        self.val_metric_values_best = None

        self.use_rtpt = use_rtpt

        if use_rtpt:
            # we use k-steps instead of epochs here
            self.rtpt = RTPT(name_initials=self.tc.name_initials,
                             experiment_name='crazyara',
                             max_iterations=self.k_steps_end -
                             self.tc.k_steps_initial)

    def _log_metrics(self, metric_values, global_step, prefix="train_"):
        """
        Logs a dictionary of metric values to the console and to tensorboard
        if _log_metrics_to_tensorboard is set to true
        :param metric_values: Dictionary object storing the current metrics
        :param global_step: X-Position point of all metric entries
        :param prefix: Used for labelling the metrics
        :return:
        """
        for name in metric_values.keys():  # show the metric stats
            print(" - %s%s: %.4f" % (prefix, name, metric_values[name]),
                  end="")
            # add the metrics to the tensorboard event file
            if self.tc.log_metrics_to_tensorboard:
                self.sum_writer.add_scalar(
                    name, [prefix.replace("_", ""), metric_values[name]],
                    global_step)

    def train(self, cur_it=None):  # Probably needs refactoring
        """
        Training model
        :param cur_it: Current iteration which is used for the learning rate and momentum schedule.
         If set to None it will be initialized
        :return: return_metrics_and_stop_training()
        """
        # Too many local variables (44/15) - Too many branches (18/12) - Too many statements (108/50)
        # set a custom seed for reproducibility
        if self.tc.seed is not None:
            random.seed(self.tc.seed)
        # define and initialize the variables which will be used
        self.t_s = time()
        # track how many batches have been processed in this epoch
        self.patience_cnt = epoch = self.batch_proc_tmp = 0
        self.k_steps = self.tc.k_steps_initial  # counter in thousands of steps

        if cur_it is None:
            self.cur_it = self.tc.k_steps_initial * 1000
        else:
            self.cur_it = cur_it
        self.nb_spikes = 0  # count the number of spikes that have been detected
        # initialize the loss to compare with, with a very high value
        self.old_val_loss = 9000
        self.graph_exported = False  # create a state variable to check if the net architecture has been reported yet
        self.continue_training = True
        self.optimizer.lr = self.to.lr_schedule(self.cur_it)
        if self.tc.optimizer_name == "nag":
            self.optimizer.momentum = self.to.momentum_schedule(self.cur_it)

        if not self.ordering:  # safety check to prevent an infinite loop
            raise Exception(
                "You must have at least one part file in your planes-dataset directory!"
            )

        if self.use_rtpt:
            # Start the RTPT tracking
            self.rtpt.start()

        while self.continue_training:  # Too many nested blocks (7/5)
            # reshuffle the ordering of the training game batches (shuffle works in place)
            random.shuffle(self.ordering)

            epoch += 1
            logging.info("EPOCH %d", epoch)
            logging.info("=========================")
            self.t_s_steps = time()
            self._model.init_optimizer(optimizer=self.optimizer)

            if self._augment:
                # stores part ids that were not augmented yet
                parts_not_augmented = list(set(self.ordering.copy()))
                # stores part ids that were loaded before but not augmented
                parts_to_augment = []

            for part_id in tqdm_notebook(self.ordering):

                if MODE == MODE_XIANGQI:
                    _, self.x_train, self.yv_train, self.yp_train, _ = load_xiangqi_dataset(
                        dataset_type="train",
                        part_id=part_id,
                        normalize=self.tc.normalize,
                        verbose=False)
                    if self._augment:
                        # check whether the current part should be augmented
                        if part_id in parts_to_augment:
                            augment(self.x_train, self.yp_train)
                            logging.debug(
                                "Using augmented part with id {}".format(
                                    part_id))
                        elif part_id in parts_not_augmented:
                            if random.randint(0, 1):
                                augment(self.x_train, self.yp_train)
                                parts_not_augmented.remove(part_id)
                                logging.debug(
                                    "Using augmented part with id {}".format(
                                        part_id))
                            else:
                                parts_to_augment.append(part_id)
                                logging.debug(
                                    "Using unaugmented part with id {}".format(
                                        part_id))
                else:
                    # load one chunk of the dataset from memory
                    _, self.x_train, self.yv_train, self.yp_train, plys_to_end, _ = load_pgn_dataset(
                        dataset_type="train",
                        part_id=part_id,
                        normalize=self.tc.normalize,
                        verbose=False,
                        q_value_ratio=self.tc.q_value_ratio)
                # fill_up_batch if there aren't enough games
                if len(self.yv_train) < self.tc.batch_size:
                    logging.info("filling up batch with too few samples %d" %
                                 len(self.yv_train))
                    self.x_train = fill_up_batch(self.x_train,
                                                 self.tc.batch_size)
                    self.yv_train = fill_up_batch(self.yv_train,
                                                  self.tc.batch_size)
                    self.yp_train = fill_up_batch(self.yp_train,
                                                  self.tc.batch_size)
                    if MODE != MODE_XIANGQI:
                        if plys_to_end is not None:
                            plys_to_end = fill_up_batch(
                                plys_to_end, self.tc.batch_size)

                if MODE != MODE_XIANGQI:
                    if self.tc.discount != 1:
                        self.yv_train *= self.tc.discount**plys_to_end
                self.yp_train = prepare_policy(
                    self.yp_train, self.tc.select_policy_from_plane,
                    self.tc.sparse_policy_label,
                    self.tc.is_policy_from_plane_data)

                if self.tc.use_wdl and self.tc.use_plys_to_end:
                    self._train_iter = mx.io.NDArrayIter(
                        {'data': self.x_train}, {
                            'value_label': self.yv_train,
                            'policy_label': self.yp_train,
                            'wdl_label': value_to_wdl_label(self.yv_train),
                            'plys_to_end_label':
                            prepare_plys_label(plys_to_end)
                        },
                        self.tc.batch_size,
                        shuffle=True)
                else:
                    self._train_iter = mx.io.NDArrayIter(
                        {'data': self.x_train}, {
                            'value_label': self.yv_train,
                            'policy_label': self.yp_train
                        },
                        self.tc.batch_size,
                        shuffle=True)

                # avoid memory leaks by adding synchronization
                mx.nd.waitall()

                reset_metrics(self.to.metrics)
                for batch in self._train_iter:
                    self._model.forward(batch,
                                        is_train=True)  # compute predictions
                    for metric in self.to.metrics:  # update the metrics
                        self._model.update_metric(metric, batch.label)

                    self._model.backward()
                    # compute gradients
                    self._model.update()  # update parameters
                    self.batch_callback()

                    if not self.continue_training:
                        logging.info('Elapsed time for training (hh:mm:ss): ' +
                                     str(
                                         datetime.timedelta(
                                             seconds=round(time() -
                                                           self.t_s))))

                        return return_metrics_and_stop_training(
                            self.k_steps, self.val_metric_values,
                            self.k_steps_best, self.val_metric_values_best)

                # add the graph representation of the network to the tensorboard log file
                if not self.graph_exported and self.tc.log_metrics_to_tensorboard:
                    # self.sum_writer.add_graph(self._symbol)
                    self.graph_exported = True

    def _fill_train_metrics(self):
        """
        Fills in the training metrics
        :return:
        """
        self.train_metric_values = {}
        for metric in self.to.metrics:
            name, value = metric.get()
            self.train_metric_values[name] = value

        self.train_metric_values["loss"] = 0.01 * self.train_metric_values["value_loss"] + \
                                           0.99 * self.train_metric_values["policy_loss"]

    def recompute_eval(self):
        """
        Recomputes the score on the validation data
        :return:
        """
        ms_step = ((time() - self.t_s_steps) / self.tc.batch_steps) * 1000
        logging.info("Step %dK/%dK - %dms/step", self.k_steps,
                     self.k_steps_end, ms_step)
        logging.info("-------------------------")
        logging.debug("Iteration %d/%d", self.cur_it, self.tc.total_it)
        if self.tc.optimizer_name == "nag":
            logging.debug("lr: %.7f - momentum: %.7f", self.optimizer.lr,
                          self.optimizer.momentum)
        else:
            logging.debug("lr: %.7f - momentum: -", self.optimizer.lr)

        # the metric values have already been computed during training for the train set
        self._fill_train_metrics()

        self.val_metric_values = evaluate_metrics(
            self.to.metrics,
            self._val_iter,
            self._model,
        )
        if self.use_rtpt:
            # update process title according to loss
            self.rtpt.step(
                subtitle=f"loss={self.val_metric_values['loss']:2.2f}")
        if self.tc.use_spike_recovery and (
                self.old_val_loss * self.tc.spike_thresh <
                self.val_metric_values["loss"] or np.isnan(
                    self.val_metric_values["loss"])):  # check for spikes
            self.handle_spike()
        else:
            self.update_eval()

    def handle_spike(self):
        """
        Handles the occurrence of a spike during training, i.e. the case where the validation loss increased dramatically.
        :return: self._return_metrics_and_stop_training()
        """
        self.nb_spikes += 1
        logging.warning(
            "Spike %d/%d occurred - val_loss: %.3f",
            self.nb_spikes,
            self.tc.max_spikes,
            self.val_metric_values["loss"],
        )
        if self.nb_spikes >= self.tc.max_spikes:
            val_loss = self.val_metric_values["loss"]
            val_p_acc = self.val_metric_values["policy_acc"]
            # finally stop training because the maximum number of spikes has been reached

            logging.debug(
                "The maximum number of spikes has been reached. Stop training."
            )
            self.continue_training = False

            if self.tc.log_metrics_to_tensorboard:
                self.sum_writer.close()
            return return_metrics_and_stop_training(
                self.k_steps, self.val_metric_values, self.k_steps_best,
                self.val_metric_values_best)

        logging.debug("Recover to latest checkpoint")
        # Load the best model once again
        prefix = self.tc.export_dir + "weights/model-%.5f-%.3f" % (
            self.val_loss_best, self.val_p_acc_best)

        logging.debug("load current best model:%s", prefix)
        self._model.load(prefix, epoch=self.k_steps_best)

        self.k_steps = self.k_steps_best
        logging.debug("k_step is back at %d", self.k_steps_best)
        # print the elapsed time
        t_delta = time() - self.t_s_steps
        print(" - %.ds" % t_delta)
        self.t_s_steps = time()

    def update_eval(self):
        """
        Updates the evaluation metrics
        :return:
        """
        # update the val_loss_value to compare with using spike recovery
        self.old_val_loss = self.val_metric_values["loss"]
        # log the metric values to tensorboard
        self._log_metrics(self.train_metric_values,
                          global_step=self.k_steps,
                          prefix="train_")
        self._log_metrics(self.val_metric_values,
                          global_step=self.k_steps,
                          prefix="val_")

        # check if a new checkpoint shall be created
        if self.val_loss_best is None or self.val_metric_values[
                "loss"] < self.val_loss_best:
            # update val_loss_best
            self.val_loss_best = self.val_metric_values["loss"]
            self.val_p_acc_best = self.val_metric_values["policy_acc"]
            self.val_metric_values_best = self.val_metric_values
            self.k_steps_best = self.k_steps

            if self.tc.export_weights:
                prefix = self.tc.export_dir + "weights/model-%.5f-%.3f" % (
                    self.val_loss_best, self.val_p_acc_best)
                # the export function saves both the architecture and the weights
                print()
                self._model.save_checkpoint(prefix, epoch=self.k_steps_best)

            self.patience_cnt = 0  # reset the patience counter
        # print the elapsed time
        t_delta = time() - self.t_s_steps
        print(" - %.ds" % t_delta)
        self.t_s_steps = time()

        # log the samples per second metric to tensorboard
        self.sum_writer.add_scalar(
            tag="samples_per_second",
            value={
                "hybrid_sync":
                self.tc.batch_size * self.tc.batch_steps / t_delta
            },
            global_step=self.k_steps,
        )

        # log the current learning rate
        self.sum_writer.add_scalar(tag="lr",
                                   value=self.to.lr_schedule(self.cur_it),
                                   global_step=self.k_steps)
        if self.tc.optimizer_name == "nag":
            # log the current momentum value
            self.sum_writer.add_scalar(tag="momentum",
                                       value=self.to.momentum_schedule(
                                           self.cur_it),
                                       global_step=self.k_steps)

        if self.cur_it >= self.tc.total_it:

            self.continue_training = False

            self.val_loss = self.val_metric_values["loss"]
            self.val_p_acc = self.val_metric_values["policy_acc"]
            # finally stop training because the number of lr drops has been achieved
            logging.debug("The number of given iterations has been reached")

            if self.tc.log_metrics_to_tensorboard:
                self.sum_writer.close()

    def batch_callback(self):
        """
        Callback which is executed after every batch to update the momentum and learning rate
        :return:
        """

        # update the learning rate and momentum
        self.optimizer.lr = self.to.lr_schedule(self.cur_it)
        if self.tc.optimizer_name == "nag":
            self.optimizer.momentum = self.to.momentum_schedule(self.cur_it)

        self.cur_it += 1
        self.batch_proc_tmp += 1

        if self.batch_proc_tmp >= self.tc.batch_steps:  # show metrics every batch_steps steps
            self.batch_proc_tmp = self.batch_proc_tmp - self.tc.batch_steps
            # update the counters
            self.k_steps += 1
            self.patience_cnt += 1
            self.recompute_eval()
            self.custom_metric_eval()

    def custom_metric_eval(self):
        """
        Evaluates the model based on the validation set of different variants
        """

        if self.to.variant_metrics is None:
            return

        for part_id, variant_name in enumerate(self.to.variant_metrics):
            # load one chunk of the dataset from memory
            _, x_val, yv_val, yp_val, _, _ = load_pgn_dataset(
                dataset_type="val",
                part_id=part_id,
                normalize=self.tc.normalize,
                verbose=False,
                q_value_ratio=self.tc.q_value_ratio)

            if self.tc.select_policy_from_plane:
                val_iter = mx.io.NDArrayIter({'data': x_val}, {
                    'value_label':
                    yv_val,
                    'policy_label':
                    np.array(FLAT_PLANE_IDX)[yp_val.argmax(axis=1)]
                }, self.tc.batch_size)
            else:
                val_iter = mx.io.NDArrayIter(
                    {'data': x_val}, {
                        'value_label': yv_val,
                        'policy_label': yp_val.argmax(axis=1)
                    }, self.tc.batch_size)

            results = self._model.score(val_iter, self.to.metrics)
            prefix = "val_"

            for entry in results:
                name = variant_name + "_" + entry[0]
                value = entry[1]
                print(" - %s%s: %.4f" % (prefix, name, value), end="")
                # add the metrics to the tensorboard event file
                if self.tc.log_metrics_to_tensorboard:
                    self.sum_writer.add_scalar(
                        name, [prefix.replace("_", ""), value], self.k_steps)
        print()
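`fill_up_batch`, used above when a data part holds fewer samples than one batch, is not shown; a plausible sketch that repeats the samples along axis 0 until the batch size is reached:

import numpy as np

def fill_up_batch(x, batch_size):
    # hypothetical helper: tile the samples until at least batch_size
    # entries exist, then cut off the surplus
    reps = int(np.ceil(batch_size / len(x)))
    return np.concatenate([x] * reps, axis=0)[:batch_size]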
Ejemplo n.º 19
0
def train():
    if config.restart_training:
        shutil.rmtree(config.output_dir, ignore_errors=True)
    if config.output_dir is None:
        config.output_dir = 'output'
    if not os.path.exists(config.output_dir):
        os.makedirs(config.output_dir)
    logger = setup_logger(os.path.join(config.output_dir, 'train_log'))
    logger.info('train with gpu %s and mxnet %s' %
                (config.gpu_id, mx.__version__))

    ctx = mx.gpu(config.gpu_id)
    # set the random seed
    mx.random.seed(2)
    mx.random.seed(2, ctx=ctx)

    train_transforms = transforms.Compose(
        [transforms.RandomBrightness(0.5),
         transforms.ToTensor()])
    train_dataset = ImageDataset(config.trainfile,
                                 (config.img_h, config.img_w),
                                 3,
                                 80,
                                 config.alphabet,
                                 phase='train')
    train_data_loader = DataLoader(
        train_dataset.transform_first(train_transforms),
        config.train_batch_size,
        shuffle=True,
        last_batch='keep',
        num_workers=config.workers)
    test_dataset = ImageDataset(config.testfile, (config.img_h, config.img_w),
                                3,
                                80,
                                config.alphabet,
                                phase='test')
    test_data_loader = DataLoader(test_dataset.transform_first(
        transforms.ToTensor()),
                                  config.eval_batch_size,
                                  shuffle=True,
                                  last_batch='keep',
                                  num_workers=config.workers)
    net = CRNN(len(config.alphabet), hidden_size=config.nh)
    net.hybridize()
    if not config.restart_training and config.checkpoint != '':
        logger.info('load pretrained net from {}'.format(config.checkpoint))
        net.load_parameters(config.checkpoint, ctx=ctx)
    else:
        net.initialize(ctx=ctx)

    criterion = gluon.loss.CTCLoss()

    all_step = len(train_data_loader)
    logger.info('each epoch contains {} steps'.format(all_step))
    schedule = mx.lr_scheduler.FactorScheduler(step=config.lr_decay_step *
                                               all_step,
                                               factor=config.lr_decay,
                                               stop_factor_lr=config.end_lr)
    # schedule = mx.lr_scheduler.MultiFactorScheduler(step=[15 * all_step, 30 * all_step, 60 * all_step,80 * all_step],
    #                                                 factor=0.1)
    adam_optimizer = mx.optimizer.Adam(learning_rate=config.lr,
                                       lr_scheduler=schedule)
    trainer = gluon.Trainer(net.collect_params(), optimizer=adam_optimizer)

    sw = SummaryWriter(logdir=config.output_dir)
    for epoch in range(config.start_epoch, config.end_epoch):
        loss = 0.0
        train_acc = 0.0
        tick = time.time()
        cur_step = 0
        for i, (data, label) in enumerate(train_data_loader):
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)

            with autograd.record():
                output = net(data)
                loss_ctc = criterion(output, label)
            loss_ctc.backward()
            trainer.step(data.shape[0])

            loss_c = loss_ctc.mean()
            cur_step = epoch * all_step + i
            sw.add_scalar(tag='ctc_loss',
                          value=loss_c.asscalar(),
                          global_step=cur_step // 2)
            sw.add_scalar(tag='lr',
                          value=trainer.learning_rate,
                          global_step=cur_step // 2)
            loss += loss_c
            acc = accuracy(output, label, config.alphabet)
            train_acc += acc
            if (i + 1) % config.display_interval == 0:
                acc /= len(label)
                sw.add_scalar(tag='train_acc', value=acc, global_step=cur_step)
                batch_time = time.time() - tick
                logger.info(
                    '[{}/{}], [{}/{}], step: {}, Speed: {:.3f} samples/sec, ctc loss: {:.4f}, acc: {:.4f}, lr: {},'
                    ' time:{:.4f} s'.format(
                        epoch, config.end_epoch, i, all_step, cur_step,
                        config.display_interval * config.train_batch_size /
                        batch_time,
                        loss.asscalar() / config.display_interval, acc,
                        trainer.learning_rate, batch_time))
                loss = .0
                tick = time.time()
                nd.waitall()
        if epoch == 0:
            sw.add_graph(net)
        logger.info('start val ....')
        train_acc /= train_dataset.__len__()
        validation_accuracy = evaluate_accuracy(
            net, test_data_loader, ctx,
            config.alphabet) / test_dataset.__len__()
        sw.add_scalar(tag='val_acc',
                      value=validation_accuracy,
                      global_step=cur_step)
        logger.info("Epoch {},train_acc {:.4f}, val_acc {:.4f}".format(
            epoch, train_acc, validation_accuracy))
        net.save_parameters("{}/{}_{:.4f}_{:.4f}.params".format(
            config.output_dir, epoch, train_acc, validation_accuracy))
    sw.close()
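The `accuracy` helper used in the loop above is not shown; it evidently returns the number of correctly recognized samples in a batch (it is later divided by the dataset length). A rough sketch assuming greedy CTC decoding with blank index 0 and padded labels:

def accuracy(output, label, alphabet):
    # hypothetical greedy CTC decoder: argmax per time step, collapse
    # repeats, drop blanks (index 0), then compare with the padded label;
    # assumes output layout (batch, seq_len, num_classes)
    pred = output.argmax(axis=2).asnumpy()
    correct = 0
    for p, l in zip(pred, label.asnumpy()):
        decoded, prev = [], -1
        for idx in p:
            if idx != prev and idx != 0:
                decoded.append(int(idx))
            prev = idx
        target = [int(x) for x in l if x > 0]  # strip padding and blanks
        if decoded == target:
            correct += 1
    return correct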
Ejemplo n.º 20
0
def train():
    # initialization
    ctx = try_gpu(2)
    net = models.resnet50_v1(classes=4)
    net.hybridize()
    net.initialize(ctx=ctx)
    # net.forward(nd.ones((1, 3, 227, 227)).as_in_context(ctx))

    sw = SummaryWriter(
        '/data1/lsp/lsp/pytorch_mnist/log/rotate/mxnet_resnet18')
    # sw.add_graph(net)

    print('initialize weights on', ctx)

    # load the data
    batch_size = 64
    epochs = 10

    train_data = custom_dataset(txt='/data2/dataset/image/train.txt',
                                data_shape=(224, 224),
                                channel=3)
    test_data = custom_dataset(txt='/data2/dataset/image/val.txt',
                               data_shape=(224, 224),
                               channel=3)
    transforms_train = transforms.ToTensor()
    # transforms_train = transforms.Compose([transforms.Resize(227), transforms.ToTensor()])
    train_data_loader = gluon.data.DataLoader(
        train_data.transform_first(transforms_train),
        batch_size=batch_size,
        shuffle=True,
        num_workers=12)

    test_data_loader = gluon.data.DataLoader(
        test_data.transform_first(transforms_train),
        batch_size=batch_size,
        shuffle=True,
        num_workers=12)
    # training
    criterion = gluon.loss.SoftmaxCrossEntropyLoss()

    steps = train_data.__len__() // batch_size

    schedule = mx.lr_scheduler.FactorScheduler(step=3 * steps,
                                               factor=0.1,
                                               stop_factor_lr=1e-6)
    sgd_optimizer = mx.optimizer.SGD(learning_rate=0.01, lr_scheduler=schedule)
    trainer = gluon.Trainer(net.collect_params(), optimizer=sgd_optimizer)

    for epoch in range(epochs):
        # test_data.reset()
        start = time.time()
        train_loss = 0.0
        train_acc = 0.0
        cur_step = 0
        n = train_data.__len__()
        for i, (data, label) in enumerate(train_data_loader):
            label = label.astype('float32').as_in_context(ctx)
            data = data.as_in_context(ctx)
            with autograd.record():
                outputs = net(data)
                loss = criterion(outputs, label)
            loss.backward()
            trainer.step(batch_size)

            cur_loss = loss.sum().asscalar()
            cur_acc = nd.sum(outputs.argmax(axis=1) == label).asscalar()
            train_acc += cur_acc
            train_loss += cur_loss
            if i % 100 == 0:
                batch_time = time.time() - start
                print(
                    'epoch [%d/%d], Iter: [%d/%d]. Loss: %.4f. Accuracy: %.4f, time:%0.4f, lr:%s'
                    % (epoch, epochs, i, steps, cur_loss, cur_acc / batch_size,
                       batch_time, trainer.learning_rate))
                start = time.time()
            cur_step = epoch * steps + i
            sw.add_scalar(tag='Train/loss',
                          value=cur_loss / label.shape[0],
                          global_step=cur_step)
            sw.add_scalar(tag='Train/acc',
                          value=cur_acc / label.shape[0],
                          global_step=cur_step)
            sw.add_scalar(tag='Train/lr',
                          value=trainer.learning_rate,
                          global_step=cur_step)

        val_acc = evaluate_accuracy(test_data_loader, net, ctx)
        sw.add_scalar(tag='Eval/acc', value=val_acc, global_step=cur_step)
        net.save_parameters("models/resnet501/{}_{}.params".format(
            epoch, val_acc))
        print(
            'epoch: %d, train_loss: %.4f, train_acc: %.4f, val_acc: %.4f, time: %.4f, lr=%s'
            % (epoch, train_loss / n, train_acc / n, val_acc,
               time.time() - start, str(trainer.learning_rate)))
    sw.close()
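`try_gpu` above is a common Gluon convenience helper (not shown) that falls back to the CPU when the requested GPU is unavailable; a minimal sketch:

import mxnet as mx

def try_gpu(i=0):
    # probe the device with a tiny allocation; fall back to CPU on failure
    try:
        ctx = mx.gpu(i)
        _ = mx.nd.zeros((1,), ctx=ctx)
    except mx.base.MXNetError:
        ctx = mx.cpu()
    return ctx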
Ejemplo n.º 21
0
def check_add_audio(data):
    sw = SummaryWriter(logdir=_LOGDIR)
    sw.add_audio(tag='test_add_audio', audio=data)
    sw.close()
    check_event_file_and_remove_logdir()
Ejemplo n.º 22
0
class TrainerAgentMXNET:  # Probably needs refactoring
    """Main training loop"""

    # x_train = yv_train = yp_train = None

    def __init__(
        self,
        model,
        symbol,
        val_iter,
        nb_parts,
        lr_schedule,
        momentum_schedule,
        total_it,
        optimizer_name="nag",  # or "adam"
        wd=0.0001,
        batch_steps=1000,
        k_steps_initial=0,
        cpu_count=16,
        batch_size=2048,
        normalize=True,
        export_weights=True,
        export_grad_histograms=True,
        log_metrics_to_tensorboard=True,
        ctx=mx.gpu(),
        metrics=None,  # clip_gradient=60,
        use_spike_recovery=True,
        max_spikes=5,
        spike_thresh=1.5,
        seed=42,
        val_loss_factor=0.01,
        policy_loss_factor=0.99,
        select_policy_from_plane=True,
        discount=1,  # 0.995,
        sparse_policy_label=True,
        q_value_ratio=0,
        cwd=None,
        variant_metrics=None,
        # prefix for the process name in order to identify the process on a server
        name_initials="JC"):
        # Too many instance attributes (29/7) - Too many arguments (24/5) - Too many local variables (25/15)
        # Too few public methods (1/2)
        # , lr_warmup_k_steps=30, lr_warmup_init=0.01):
        if metrics is None:
            metrics = {}
        self._log_metrics_to_tensorboard = log_metrics_to_tensorboard
        self._ctx = ctx
        self._metrics = metrics
        self._model = model
        self._symbol = symbol
        self._graph_exported = False
        self._normalize = normalize
        self._lr_schedule = lr_schedule
        self._momentum_schedule = momentum_schedule
        self._total_it = total_it
        self._batch_size = batch_size
        self._export_grad_histograms = export_grad_histograms
        self._cpu_count = cpu_count
        self._k_steps_initial = k_steps_initial
        self._val_iter = val_iter
        self._export_weights = export_weights
        self._batch_steps = batch_steps
        self._use_spike_recovery = use_spike_recovery
        self._max_spikes = max_spikes
        self._spike_thresh = spike_thresh
        self._seed = seed
        self._val_loss_factor = val_loss_factor
        self._policy_loss_factor = policy_loss_factor
        self.x_train = self.yv_train = self.yp_train = None
        self.discount = discount
        self._q_value_ratio = q_value_ratio
        # defines if the policy target is one-hot encoded (sparse=True) or a target distribution (sparse=False)
        self.sparse_policy_label = sparse_policy_label
        # define the current working directory
        if cwd is None:
            self.cwd = os.getcwd()
        else:
            self.cwd = cwd
        # define a summary writer that logs data and flushes to the file every 5 seconds
        if log_metrics_to_tensorboard:
            self.sum_writer = SummaryWriter(logdir="%s/logs" % self.cwd,
                                            flush_secs=5,
                                            verbose=False)
        # Define the two loss functions
        self.optimizer_name = optimizer_name
        if optimizer_name == "adam":
            self.optimizer = mx.optimizer.Adam(learning_rate=0.001,
                                               beta1=0.9,
                                               beta2=0.999,
                                               epsilon=1e-8,
                                               lazy_update=True,
                                               rescale_grad=(1.0 / batch_size))
        elif optimizer_name == "nag":
            self.optimizer = mx.optimizer.NAG(momentum=momentum_schedule(0),
                                              wd=wd,
                                              rescale_grad=(1.0 / batch_size))
        else:
            raise Exception("%s is currently not supported as an optimizer." %
                            optimizer_name)
        self.ordering = list(
            range(nb_parts)
        )  # define a list which describes the order of the processed batches
        # decides if the policy indices shall be selected directly from spatial feature maps without dense layer
        self.select_policy_from_plane = select_policy_from_plane

        self.batch_end_callbacks = [self.batch_callback]

        # few variables which are internally used
        self.val_loss_best = self.val_p_acc_best = self.k_steps_best = \
            self.old_label = self.value_out = self.t_s = None
        self.patience_cnt = self.batch_proc_tmp = None
        # calculate how many log states will be processed
        self.k_steps_end = self._total_it / self._batch_steps
        self.k_steps = self.cur_it = self.nb_spikes = self.old_val_loss = self.continue_training = self.t_s_steps = None
        self._train_iter = self.graph_exported = self.val_metric_values = self.val_loss = self.val_p_acc = None
        self.variant_metrics = variant_metrics
        self.name_initials = name_initials
        # we use k-steps instead of epochs here
        self.rtpt = RTPT(name_initials=name_initials,
                         base_title='crazyara_training',
                         number_of_epochs=self.k_steps_end,
                         epoch_n=self._k_steps_initial)

    def _log_metrics(self, metric_values, global_step, prefix="train_"):
        """
        Logs a dictionary of metric values to the console and to tensorboard
        if _log_metrics_to_tensorboard is set to true
        :param metric_values: Dictionary object storing the current metrics
        :param global_step: X-Position point of all metric entries
        :param prefix: Used for labelling the metrics
        :return:
        """
        for name in metric_values.keys():  # show the metric stats
            print(" - %s%s: %.4f" % (prefix, name, metric_values[name]),
                  end="")
            # add the metrics to the tensorboard event file
            if self._log_metrics_to_tensorboard:
                self.sum_writer.add_scalar(
                    name, [prefix.replace("_", ""), metric_values[name]],
                    global_step)

    def train(self, cur_it=None):  # Probably needs refactoring
        """
        Training model
        :param cur_it: Current iteration which is used for the learning rate and momentum schedule.
         If set to None it will be initialized
        :return: self._return_metrics_and_stop_training()
        """
        # Too many local variables (44/15) - Too many branches (18/12) - Too many statements (108/50)
        # set a custom seed for reproducibility
        if self._seed is not None:
            random.seed(self._seed)
        # define and initialize the variables which will be used
        self.t_s = time()
        # track how many batches have been processed in this epoch
        self.patience_cnt = epoch = self.batch_proc_tmp = 0
        self.k_steps = self._k_steps_initial  # counter in thousands of steps

        # inform rtpt that training has started
        self.rtpt.epoch_starts()

        if cur_it is None:
            self.cur_it = self._k_steps_initial * 1000
        else:
            self.cur_it = cur_it
        self.nb_spikes = 0  # count the number of spikes that have been detected
        # initialize the loss to compare with, with a very high value
        self.old_val_loss = 9000
        self.graph_exported = False  # create a state variable to check if the net architecture has been reported yet
        self.continue_training = True
        self.optimizer.lr = self._lr_schedule(self.cur_it)
        if self.optimizer_name == "nag":
            self.optimizer.momentum = self._momentum_schedule(self.cur_it)

        if not self.ordering:  # safety check to prevent an infinite loop
            raise Exception(
                "You must have at least one part file in your planes-dataset directory!"
            )

        while self.continue_training:  # Too many nested blocks (7/5)
            # reshuffle the ordering of the training game batches (shuffle works in place)
            random.shuffle(self.ordering)

            epoch += 1
            logging.info("EPOCH %d", epoch)
            logging.info("=========================")
            self.t_s_steps = time()
            self._model.init_optimizer(optimizer=self.optimizer)

            for part_id in tqdm_notebook(self.ordering):

                # load one chunk of the dataset from memory
                _, self.x_train, self.yv_train, self.yp_train, plys_to_end, _ = load_pgn_dataset(
                    dataset_type="train",
                    part_id=part_id,
                    normalize=self._normalize,
                    verbose=False,
                    q_value_ratio=self._q_value_ratio)
                # fill_up_batch if there aren't enough games
                if len(self.yv_train) < self._batch_size:
                    logging.info("filling up batch with too few samples %d" %
                                 len(self.yv_train))
                    self.x_train = fill_up_batch(self.x_train,
                                                 self._batch_size)
                    self.yv_train = fill_up_batch(self.yv_train,
                                                  self._batch_size)
                    self.yp_train = fill_up_batch(self.yp_train,
                                                  self._batch_size)
                    if plys_to_end is not None:
                        plys_to_end = fill_up_batch(plys_to_end,
                                                    self._batch_size)

                if self.discount != 1:
                    self.yv_train *= self.discount**plys_to_end

                self.yp_train = prepare_policy(self.yp_train,
                                               self.select_policy_from_plane,
                                               self.sparse_policy_label)

                self._train_iter = mx.io.NDArrayIter(
                    {'data': self.x_train}, {
                        'value_label': self.yv_train,
                        'policy_label': self.yp_train
                    },
                    self._batch_size,
                    shuffle=True)

                # avoid memory leaks by adding synchronization
                mx.nd.waitall()

                reset_metrics(self._metrics)
                for batch in self._train_iter:
                    self._model.forward(batch,
                                        is_train=True)  # compute predictions
                    for metric in self._metrics:  # update the metrics
                        self._model.update_metric(metric, batch.label)

                    self._model.backward()
                    # compute gradients
                    self._model.update()  # update parameters
                    self.batch_callback()

                    if not self.continue_training:
                        logging.info('Elapsed time for training (hh:mm:ss): ' +
                                     str(
                                         datetime.timedelta(
                                             seconds=round(time() -
                                                           self.t_s))))

                        return self._return_metrics_and_stop_training()

                # add the graph representation of the network to the tensorboard log file
                if not self.graph_exported and self._log_metrics_to_tensorboard:
                    # self.sum_writer.add_graph(self._symbol)
                    self.graph_exported = True

    def _return_metrics_and_stop_training(self):
        return (self.k_steps, self.val_metric_values["value_loss"],
                self.val_metric_values["policy_loss"],
                self.val_metric_values["value_acc_sign"], self.val_metric_values["policy_acc"]), \
               (self.k_steps_best, self.val_loss_best, self.val_p_acc_best)

    def _fill_train_metrics(self):
        """
        Fills in the training metrics
        :return:
        """
        self.train_metric_values = {}
        for metric in self._metrics:
            name, value = metric.get()
            self.train_metric_values[name] = value

        self.train_metric_values["loss"] = 0.01 * self.train_metric_values["value_loss"] + \
                                           0.99 * self.train_metric_values["policy_loss"]

    def recompute_eval(self):
        """
        Recomputes the score on the validation data
        :return:
        """
        ms_step = ((time() - self.t_s_steps) / self._batch_steps) * 1000
        logging.info("Step %dK/%dK - %dms/step", self.k_steps,
                     self.k_steps_end, ms_step)
        logging.info("-------------------------")
        logging.debug("Iteration %d/%d", self.cur_it, self._total_it)
        self.rtpt.epoch_ends()  # update proctitle
        if self.optimizer_name == "nag":
            logging.debug("lr: %.7f - momentum: %.7f", self.optimizer.lr,
                          self.optimizer.momentum)
        else:
            logging.debug("lr: %.7f - momentum: -", self.optimizer.lr)

        # the metric values have already been computed during training for the train set
        self._fill_train_metrics()

        self.val_metric_values = evaluate_metrics(
            self._metrics,
            self._val_iter,
            self._model,
        )
        if self._use_spike_recovery and (
                self.old_val_loss * self._spike_thresh <
                self.val_metric_values["loss"] or np.isnan(
                    self.val_metric_values["loss"])):  # check for spikes
            self.handle_spike()
        else:
            self.update_eval()

    def handle_spike(self):
        """
        Handles the occurrence of a spike during training, i.e. the case where the validation loss increased dramatically.
        :return: self._return_metrics_and_stop_training()
        """
        self.nb_spikes += 1
        logging.warning(
            "Spike %d/%d occurred - val_loss: %.3f",
            self.nb_spikes,
            self._max_spikes,
            self.val_metric_values["loss"],
        )
        if self.nb_spikes >= self._max_spikes:
            val_loss = self.val_metric_values["loss"]
            val_p_acc = self.val_metric_values["policy_acc"]
            # finally stop training because the maximum number of spikes has been reached

            logging.debug(
                "The maximum number of spikes has been reached. Stop training."
            )
            self.continue_training = False

            if self._log_metrics_to_tensorboard:
                self.sum_writer.close()
            return self._return_metrics_and_stop_training()

        logging.debug("Recover to latest checkpoint")
        # Load the best model once again
        prefix = "%s/weights/model-%.5f-%.3f" % (self.cwd, self.val_loss_best,
                                                 self.val_p_acc_best)

        logging.debug("load current best model:%s", prefix)
        # self._net.load_parameters(model_path, ctx=self._ctx)
        self._model.load(prefix, epoch=self.k_steps_best)

        self.k_steps = self.k_steps_best
        logging.debug("k_step is back at %d", self.k_steps_best)
        # print the elapsed time
        t_delta = time() - self.t_s_steps
        print(" - %.ds" % t_delta)
        self.t_s_steps = time()

    def update_eval(self):
        """
        Updates the evaluation metrics
        :return:
        """
        # update the val_loss_value to compare with using spike recovery
        self.old_val_loss = self.val_metric_values["loss"]
        # log the metric values to tensorboard
        self._log_metrics(self.train_metric_values,
                          global_step=self.k_steps,
                          prefix="train_")
        self._log_metrics(self.val_metric_values,
                          global_step=self.k_steps,
                          prefix="val_")

        # check if a new checkpoint shall be created
        if self.val_loss_best is None or self.val_metric_values[
                "loss"] < self.val_loss_best:
            # update val_loss_best
            self.val_loss_best = self.val_metric_values["loss"]
            self.val_p_acc_best = self.val_metric_values["policy_acc"]
            self.k_steps_best = self.k_steps

            if self._export_weights:
                prefix = "%s/weights/model-%.5f-%.3f" % (
                    self.cwd, self.val_loss_best, self.val_p_acc_best)
                # the export function saves both the architecture and the weights
                print()
                self._model.save_checkpoint(prefix, epoch=self.k_steps_best)

            self.patience_cnt = 0  # reset the patience counter
        # print the elapsed time
        t_delta = time() - self.t_s_steps
        print(" - %.ds" % t_delta)
        self.t_s_steps = time()

        # log the samples per second metric to tensorboard
        self.sum_writer.add_scalar(
            tag="samples_per_second",
            value={
                "hybrid_sync": self._batch_size * self._batch_steps / t_delta
            },
            global_step=self.k_steps,
        )

        # log the current learning rate
        self.sum_writer.add_scalar(tag="lr",
                                   value=self._lr_schedule(self.cur_it),
                                   global_step=self.k_steps)
        if self.optimizer_name == "nag":
            # log the current momentum value
            self.sum_writer.add_scalar(tag="momentum",
                                       value=self._momentum_schedule(
                                           self.cur_it),
                                       global_step=self.k_steps)

        if self.cur_it >= self._total_it:

            self.continue_training = False

            self.val_loss = self.val_metric_values["loss"]
            self.val_p_acc = self.val_metric_values["policy_acc"]
            # finally stop training because the number of lr drops has been achieved
            logging.debug("The number of given iterations has been reached")

            if self._log_metrics_to_tensorboard:
                self.sum_writer.close()

    def batch_callback(self):
        """
        Callback which is executed after every batch to update the momentum and learning rate
        :return:
        """

        # update the learning rate and momentum
        self.optimizer.lr = self._lr_schedule(self.cur_it)
        if self.optimizer_name == "nag":
            self.optimizer.momentum = self._momentum_schedule(self.cur_it)

        self.cur_it += 1
        self.batch_proc_tmp += 1

        if self.batch_proc_tmp >= self._batch_steps:  # show metrics every batch_steps steps
            self.batch_proc_tmp = self.batch_proc_tmp - self._batch_steps
            # update the counters
            self.k_steps += 1
            self.patience_cnt += 1
            self.recompute_eval()
            self.custom_metric_eval()

    def custom_metric_eval(self):
        """
        Evaluates the model based on the validation set of different variants
        """

        if self.variant_metrics is None:
            return

        for part_id, variant_name in enumerate(self.variant_metrics):
            # load one chunk of the dataset from memory
            _, x_val, yv_val, yp_val, _, _ = load_pgn_dataset(
                dataset_type="val",
                part_id=part_id,
                normalize=self._normalize,
                verbose=False,
                q_value_ratio=self._q_value_ratio)

            if self.select_policy_from_plane:
                val_iter = mx.io.NDArrayIter({'data': x_val}, {
                    'value_label':
                    yv_val,
                    'policy_label':
                    np.array(FLAT_PLANE_IDX)[yp_val.argmax(axis=1)]
                }, self._batch_size)
            else:
                val_iter = mx.io.NDArrayIter(
                    {'data': x_val}, {
                        'value_label': yv_val,
                        'policy_label': yp_val.argmax(axis=1)
                    }, self._batch_size)

            results = self._model.score(val_iter, self._metrics)
            prefix = "val_"

            for entry in results:
                name = variant_name + "_" + entry[0]
                value = entry[1]
                print(" - %s%s: %.4f" % (prefix, name, value), end="")
                # add the metrics to the tensorboard event file
                if self._log_metrics_to_tensorboard:
                    self.sum_writer.add_scalar(
                        name, [prefix.replace("_", ""), value], self.k_steps)
        print()
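The `add_scalar(name, [prefix, value], step)` calls above rely on mxboard's list form, which plots several named curves under one shared tag. A minimal, self-contained sketch of that pattern, assuming only that mxboard is installed ('demo_logs' is a hypothetical directory):

# Sketch: mxboard's grouped-curve logging, as used by _log_metrics above.
from mxboard import SummaryWriter

sw = SummaryWriter(logdir='demo_logs', flush_secs=5, verbose=False)
for step, (train_loss, val_loss) in enumerate([(0.92, 1.01), (0.74, 0.85)]):
    # a [name, value] pair plots a named curve under the shared tag 'loss'
    sw.add_scalar('loss', ['train', train_loss], global_step=step)
    sw.add_scalar('loss', ['val', val_loss], global_step=step)
sw.close()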
Ejemplo n.º 23
0
def test_add_histogram():
    shape = rand_shape_nd(4)
    sw = SummaryWriter(logdir=_LOGDIR)
    sw.add_histogram(tag='test_add_histogram', values=mx.nd.random.normal(shape=shape), global_step=0, bins=100)
    sw.close()
    check_event_file_and_remove_logdir()
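The tests above delegate verification and cleanup to `check_event_file_and_remove_logdir()`, which is not shown in this collection. A plausible sketch of such a helper, assuming `_LOGDIR` is the module-level log directory; the real test suite may differ:

import os
import shutil

_LOGDIR = './test_logs'  # hypothetical stand-in for the module-level constant

def check_event_file_and_remove_logdir():
    """Verify that exactly one event file was written, then delete the logdir."""
    assert os.path.isdir(_LOGDIR)
    # files written by SummaryWriter start with 'events.out.tfevents'
    event_files = [f for f in os.listdir(_LOGDIR)
                   if f.startswith('events.out.tfevents')]
    assert len(event_files) == 1
    shutil.rmtree(_LOGDIR)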
Ejemplo n.º 24
0
def train(epochs, ctx):
    # Collect all parameters from net and its children, then initialize them.
    net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
    net.hybridize()

    # Trainer is for updating parameters with gradient.
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': opt.lr, 'momentum': opt.momentum})
    metric = mx.metric.Accuracy()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()

    # collect parameter names for logging the gradients of parameters in each epoch
    params = net.collect_params()
    param_names = params.keys()

    # define a summary writer that logs data and flushes to the file every 5 seconds
    sw = SummaryWriter(logdir='./logs', flush_secs=5)

    global_step = 0
    for epoch in range(epochs):
        # reset data iterator and metric at beginning of epoch.
        metric.reset()
        for i, (data, label) in enumerate(train_data):
            # Copy data to ctx if necessary
            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)
            # Start recording computation graph with record() section.
            # Recorded graphs can then be differentiated with backward.
            with autograd.record():
                output = net(data)
                L = loss(output, label)
            sw.add_scalar(tag='cross_entropy', value=L.mean().asscalar(), global_step=global_step)
            global_step += 1
            L.backward()

            # take a gradient step with batch_size equal to data.shape[0]
            trainer.step(data.shape[0])
            # update metric at last.
            metric.update([label], [output])

            if i % opt.log_interval == 0 and i > 0:
                name, train_acc = metric.get()
                print('[Epoch %d Batch %d] Training: %s=%f' % (epoch, i, name, train_acc))

            # Log the first batch of images of each epoch
            if i == 0:
                sw.add_image('mnist_first_minibatch', data.reshape((opt.batch_size, 1, 28, 28)), epoch)

        if epoch == 0:
            sw.add_graph(net)

        grads = [p.grad() for p in net.collect_params().values()]
        assert len(grads) == len(param_names)
        # logging the gradients of parameters for checking convergence
        for i, name in enumerate(param_names):
            sw.add_histogram(tag=name, values=grads[i], global_step=epoch, bins=1000)

        name, train_acc = metric.get()
        print('[Epoch %d] Training: %s=%f' % (epoch, name, train_acc))
        # logging training accuracy
        sw.add_scalar(tag='accuracy_curves', value=('train_acc', train_acc), global_step=epoch)

        name, val_acc = test(ctx)
        print('[Epoch %d] Validation: %s=%f' % (epoch, name, val_acc))
        # logging the validation accuracy
        sw.add_scalar(tag='accuracy_curves', value=('valid_acc', val_acc), global_step=epoch)

    sw.export_scalars('scalar_dict.json')
    sw.close()
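`train` reads validation accuracy from a `test(ctx)` helper that is not part of this snippet. A minimal sketch of what such a pass could look like, assuming a `val_data` DataLoader and the same `net` are in scope:

def test(ctx):
    # hypothetical validation helper mirroring the training loop above
    metric = mx.metric.Accuracy()
    for data, label in val_data:
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        output = net(data)
        metric.update([label], [output])
    return metric.get()  # (name, accuracy), as unpacked in train()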
Ejemplo n.º 25
0
def train_net(net, train_iter, valid_iter, batch_size, trainer, ctx,
              num_epochs, lr_sch, save_prefix):
    logger.info("===================START TRAINING====================")
    if use_mxboard:
        sw = SummaryWriter(logdir='logs', flush_secs=5)
    cls_loss = gluon.loss.SoftmaxCrossEntropyLoss()
    cls_acc = mx.metric.Accuracy(name="train acc")
    top_acc = 0
    iter_num = 0
    #test_acc,test_loss = test_net(net, valid_iter, ctx)
    #sw.add_graph(net) #only hybrid block supported
    param_names = net.collect_params().keys()
    for epoch in range(num_epochs):
        train_loss = []
        t0 = time.time()
        if isinstance(train_iter, mx.io.MXDataIter):
            train_iter.reset()
        total = 0
        trainer.set_learning_rate(lr_sch(epoch))
        for batch in train_iter:
            iter_num += 1
            # print("iter ",iter_num," start")
            if isinstance(batch, mx.io.DataBatch):
                X, Y = batch.data[0], batch.label[0]
                #total += X.shape[0]
                #print(total)
            else:
                X, Y = batch
            #print(X.shape,Y.shape)
            #print(Y)
            X = X.as_in_context(ctx)
            Y = Y.as_in_context(ctx)
            with autograd.record(True):
                out = net(X)
                #out = out.as_in_context(mx.cpu())
                loss = cls_loss(out, Y)
            # print(out.asnumpy()[0])
            # print('loss = ', loss.sum().asscalar())
            loss.backward()
            train_loss.append(loss.sum().asscalar())
            trainer.step(batch_size)
            cls_acc.update(Y, out)
            nd.waitall()
            #print("iter ",iter_num," end")
            if use_mxboard:
                if iter_num % 100 == 0:
                    sw.add_scalar(tag='train_loss',
                                  value=loss.mean().asscalar(),
                                  global_step=iter_num)
                    sw.add_scalar(tag='train_acc',
                                  value=cls_acc.get(),
                                  global_step=iter_num)
                    for name in net.collect_params():
                        param = net.collect_params()[name]
                        if param.grad_req != "null":
                            sw.add_histogram(tag=name,
                                             values=param.grad(),
                                             global_step=iter_num,
                                             bins=1000)

        logger.info("epoch {} lr {} {}sec".format(epoch, trainer.learning_rate,
                                                  time.time() - t0))
        train_loss, train_acc = np.mean(train_loss) / batch_size, cls_acc.get()
        logger.info("\ttrain loss {} {}".format(train_loss, train_acc))
        if epoch > 0 and (epoch % 10) == 0:
            test_acc, test_loss = test_net(net, valid_iter, ctx)
            if use_mxboard:
                sw.add_scalar(tag='test_acc',
                              value=test_acc,
                              global_step=epoch)
                sw.add_scalar(tag='test_loss',
                              value=test_loss,
                              global_step=epoch)
            if top_acc < test_acc:
                top_acc = test_acc
                logger.info('\ttop valid acc {}'.format(test_acc))
                if isinstance(net, mx.gluon.nn.HybridSequential) or isinstance(
                        net, mx.gluon.nn.HybridBlock):
                    pf = '{}_{:.3f}.params'.format(save_prefix, top_acc)
                    net.export(pf, epoch)
                else:
                    net_path = '{}top_acc_{}_{:.3f}.params'.format(
                        save_prefix, epoch, top_acc)
                    net.save_parameters(net_path)

    if use_mxboard:
        sw.close()
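`test_net(net, valid_iter, ctx)` is called above but not defined in this snippet. A sketch under the assumption that it returns `(accuracy, mean_loss)` in that order, matching the unpacking at the call sites:

def test_net(net, valid_iter, ctx):
    # hypothetical evaluation helper matching the call sites above
    acc = mx.metric.Accuracy(name="valid acc")
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
    losses = []
    if isinstance(valid_iter, mx.io.MXDataIter):
        valid_iter.reset()
    for batch in valid_iter:
        if isinstance(batch, mx.io.DataBatch):
            X, Y = batch.data[0], batch.label[0]
        else:
            X, Y = batch
        X, Y = X.as_in_context(ctx), Y.as_in_context(ctx)
        out = net(X)
        losses.append(loss_fn(out, Y).mean().asscalar())
        acc.update(Y, out)
    return acc.get()[1], float(np.mean(losses))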
Ejemplo n.º 26
0
class TrainerAgentGluon:  # Probably needs refactoring
    """Main training loop"""
    def __init__(
        self,
        net,
        val_data,
        train_config: TrainConfig,
        train_objects: TrainObjects,
        use_rtpt: bool,
    ):
        """
        Class for training the neural network.
        :param net: The NN with loaded parameters that shall be trained.
        :param val_data: The validation data loaded with gluon DataLoader.
        :param train_config: An instance of the TrainConfig data class.
        :param train_objects: An instance of the TrainObjects data class.
        :param use_rtpt: If True, an RTPT object will be created and modified within this class.
        """
        # Too many instance attributes (29/7) - Too many arguments (24/5) - Too many local variables (25/15)
        # Too few public methods (1/2)
        self.tc = train_config
        self.to = train_objects
        if self.to.metrics is None:
            self.to.metrics = {}
        self._ctx = get_context(train_config.context, train_config.device_id)
        self._net = net
        self._graph_exported = False
        self._val_data = val_data
        # define a summary writer that logs data and flushes to the file every 5 seconds
        if self.tc.log_metrics_to_tensorboard:
            self.sum_writer = SummaryWriter(logdir=self.tc.export_dir + "logs",
                                            flush_secs=5,
                                            verbose=False)
        # Define the two loss functions
        self._softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss(
            sparse_label=self.tc.sparse_policy_label)
        self._l2_loss = gluon.loss.L2Loss()
        if self.tc.optimizer_name != "nag":
            raise NotImplementedError(
                "The requested optimizer %s Isn't supported yet." %
                self.tc.optimizer_name)
        self._trainer = gluon.Trainer(
            self._net.collect_params(),
            "nag",
            {
                "learning_rate": self.to.lr_schedule(0),
                "momentum": self.to.momentum_schedule(0),
                "wd": self.tc.wd,
            },
        )

        # collect parameter names for logging the gradients of parameters in each epoch
        self._params = self._net.collect_params()
        self._param_names = self._params.keys()
        self.ordering = list(
            range(self.tc.nb_parts)
        )  # define a list which describes the order of the processed batches

        self.use_rtpt = use_rtpt
        self.rtpt = None  # Set this later in training function

    def _log_metrics(self, metric_values, global_step, prefix="train_"):
        """
        Logs a dictionary of metric values to the console and to tensorboard
        if _log_metrics_to_tensorboard is set to true
        :param metric_values: Dictionary object storing the current metrics
        :param global_step: X-Position point of all metric entries
        :param prefix: Used for labelling the metrics
        :return:
        """
        for name in metric_values.keys():  # show the metric stats
            print(" - %s%s: %.4f" % (prefix, name, metric_values[name]),
                  end="")
            # add the metrics to the tensorboard event file
            if self.tc.log_metrics_to_tensorboard:
                self.sum_writer.add_scalar(
                    name, [prefix.replace("_", ""), metric_values[name]],
                    global_step)

    def _process_on_data_plane_file(self, train_data, batch_proc_tmp):

        for _, (data, value_label, policy_label) in enumerate(train_data):
            data = data.as_in_context(self._ctx)
            value_label = value_label.as_in_context(self._ctx)
            policy_label = policy_label.as_in_context(self._ctx)

            # update a dummy metric to see a proper progress bar
            #  (the metrics will get evaluated at the end of 100k steps)
            # if self.batch_proc_tmp > 0:
            #    self._metrics['value_loss'].update(old_label, value_out)
            # old_label = value_label
            with autograd.record():
                [value_out, policy_out] = self._net(data)
                if self.tc.select_policy_from_plane and not self.tc.is_policy_from_plane_data:
                    policy_out = policy_out[:, FLAT_PLANE_IDX]
                value_loss = self._l2_loss(value_out, value_label)
                policy_loss = self._softmax_cross_entropy(
                    policy_out, policy_label)
                # weight the components of the combined loss
                combined_loss = self.tc.val_loss_factor * value_loss.sum(
                ) + self.tc.policy_loss_factor * policy_loss.sum()
                # update a dummy metric to see a proper progress bar
                self.to.metrics["value_loss"].update(preds=value_out,
                                                     labels=value_label)

            combined_loss.backward()
            self._trainer.step(data.shape[0])
            batch_proc_tmp += 1
        return batch_proc_tmp, self.to.metrics["value_loss"].get()[1]

    def train(self, cur_it=None):  # Probably needs refactoring
        """
        Trains the model.
        :param cur_it: Current iteration, used for the learning rate and momentum schedules.
         If set to None, it will be initialized from k_steps_initial.
        """
        # Too many local variables (44/15) - Too many branches (18/12) - Too many statements (108/50)
        # set a custom seed for reproducibility
        random.seed(self.tc.seed)
        # define and initialize the variables which will be used
        t_s = time()
        # predefine the local variables that will be used in the training loop
        val_loss_best = val_p_acc_best = k_steps_best = val_metric_values_best = old_label = value_out = None
        patience_cnt = epoch = batch_proc_tmp = 0  # track how many batches have been processed in this epoch
        k_steps = self.tc.k_steps_initial  # counter for thousands steps
        # calculate how many log states will be processed
        k_steps_end = round(self.tc.total_it / self.tc.batch_steps)
        # we use k-steps instead of epochs here
        if k_steps_end == 0:
            k_steps_end = 1

        if self.use_rtpt:
            self.rtpt = RTPT(name_initials=self.tc.name_initials,
                             experiment_name='crazyara',
                             max_iterations=k_steps_end -
                             self.tc.k_steps_initial)
        if cur_it is None:
            cur_it = self.tc.k_steps_initial * 1000
        nb_spikes = 0  # count the number of spikes that have been detected
        # initialize the loss used for comparison with a very high value
        old_val_loss = np.inf
        graph_exported = False  # state variable tracking whether the net architecture has been exported yet

        if not self.ordering:  # safety check to prevent eternal loop
            raise Exception(
                "You must have at least one part file in your planes-dataset directory!"
            )

        if self.use_rtpt:
            # Start the RTPT tracking
            self.rtpt.start()

        while True:  # Too many nested blocks (7/5)
            # reshuffle the ordering of the training game batches (shuffle works in place)
            random.shuffle(self.ordering)

            epoch += 1
            logging.info("EPOCH %d", epoch)
            logging.info("=========================")
            t_s_steps = time()

            for part_id in tqdm_notebook(self.ordering):
                # load one chunk of the dataset from memory
                _, x_train, yv_train, yp_train, _, _ = load_pgn_dataset(
                    dataset_type="train",
                    part_id=part_id,
                    normalize=self.tc.normalize,
                    verbose=False,
                    q_value_ratio=self.tc.q_value_ratio)

                yp_train = prepare_policy(
                    y_policy=yp_train,
                    select_policy_from_plane=self.tc.select_policy_from_plane,
                    sparse_policy_label=self.tc.sparse_policy_label,
                    is_policy_from_plane_data=self.tc.is_policy_from_plane_data
                )

                # update the train_data object
                train_dataset = gluon.data.ArrayDataset(
                    nd.array(x_train), nd.array(yv_train), nd.array(yp_train))
                train_data = gluon.data.DataLoader(
                    train_dataset,
                    batch_size=self.tc.batch_size,
                    shuffle=True,
                    num_workers=self.tc.cpu_count)

                for _, (data, value_label,
                        policy_label) in enumerate(train_data):
                    data = data.as_in_context(self._ctx)
                    value_label = value_label.as_in_context(self._ctx)
                    policy_label = policy_label.as_in_context(self._ctx)

                    # update a dummy metric to see a proper progress bar
                    #  (the metrics will get evaluated at the end of 100k steps)
                    if batch_proc_tmp > 0:
                        self.to.metrics["value_loss"].update(
                            old_label, value_out)

                    old_label = value_label
                    with autograd.record():
                        [value_out, policy_out] = self._net(data)
                        value_loss = self._l2_loss(value_out, value_label)
                        policy_loss = self._softmax_cross_entropy(
                            policy_out, policy_label)
                        # weight the components of the combined loss
                        combined_loss = (
                            self.tc.val_loss_factor * value_loss +
                            self.tc.policy_loss_factor * policy_loss)
                        # update a dummy metric to see a proper progress bar
                        # self._metrics['value_loss'].update(preds=value_out, labels=value_label)

                    combined_loss.backward()
                    learning_rate = self.to.lr_schedule(
                        cur_it)  # update the learning rate
                    self._trainer.set_learning_rate(learning_rate)
                    momentum = self.to.momentum_schedule(
                        cur_it)  # update the momentum
                    self._trainer._optimizer.momentum = momentum
                    self._trainer.step(data.shape[0])
                    cur_it += 1
                    batch_proc_tmp += 1
                    # add the graph representation of the network to the tensorboard log file
                    if not graph_exported and self.tc.log_metrics_to_tensorboard:
                        self.sum_writer.add_graph(self._net)
                        graph_exported = True

                    if batch_proc_tmp >= self.tc.batch_steps:  # show metrics every batch_steps steps
                        # update the batch_proc_tmp counter by subtracting batch_steps
                        batch_proc_tmp = batch_proc_tmp - self.tc.batch_steps
                        ms_step = (
                            (time() - t_s_steps) /
                            self.tc.batch_steps) * 1000  # measure elapsed time
                        # update the counters
                        k_steps += 1
                        patience_cnt += 1
                        logging.info("Step %dK/%dK - %dms/step", k_steps,
                                     k_steps_end, ms_step)
                        logging.info("-------------------------")
                        logging.debug("Iteration %d/%d", cur_it,
                                      self.tc.total_it)
                        logging.debug("lr: %.7f - momentum: %.7f",
                                      learning_rate, momentum)
                        train_metric_values = evaluate_metrics(
                            self.to.metrics,
                            train_data,
                            self._net,
                            nb_batches=10,  #25,
                            ctx=self._ctx,
                            sparse_policy_label=self.tc.sparse_policy_label,
                            apply_select_policy_from_plane=self.tc.
                            select_policy_from_plane
                            and not self.tc.is_policy_from_plane_data)
                        val_metric_values = evaluate_metrics(
                            self.to.metrics,
                            self._val_data,
                            self._net,
                            nb_batches=None,
                            ctx=self._ctx,
                            sparse_policy_label=self.tc.sparse_policy_label,
                            apply_select_policy_from_plane=self.tc.
                            select_policy_from_plane
                            and not self.tc.is_policy_from_plane_data)
                        if self.use_rtpt:
                            # update process title according to loss
                            self.rtpt.step(
                                subtitle=
                                f"loss={val_metric_values['loss']:2.2f}")
                        if self.tc.use_spike_recovery and (
                                old_val_loss * self.tc.spike_thresh <
                                val_metric_values["loss"]
                                or np.isnan(val_metric_values["loss"])
                        ):  # check for spikes
                            nb_spikes += 1
                            logging.warning(
                                "Spike %d/%d occurred - val_loss: %.3f",
                                nb_spikes,
                                self.tc.max_spikes,
                                val_metric_values["loss"],
                            )
                            if nb_spikes >= self.tc.max_spikes:
                                val_loss = val_metric_values["loss"]
                                val_p_acc = val_metric_values["policy_acc"]
                                logging.debug(
                                    "The maximum number of spikes has been reached. Stop training."
                                )
                                # finally stop training because the number of lr drops has been achieved
                                print()
                                print("Elapsed time for training(hh:mm:ss): " +
                                      str(
                                          datetime.timedelta(
                                              seconds=round(time() - t_s))))

                                if self.tc.log_metrics_to_tensorboard:
                                    self.sum_writer.close()
                                return return_metrics_and_stop_training(
                                    k_steps, val_metric_values, k_steps_best,
                                    val_metric_values_best)

                            logging.debug("Recover to latest checkpoint")
                            model_path = self.tc.export_dir + "weights/model-%.5f-%.3f-%04d.params" % (
                                val_loss_best,
                                val_p_acc_best,
                                k_steps_best,
                            )  # Load the best model once again
                            logging.debug("load current best model:%s",
                                          model_path)
                            self._net.load_parameters(model_path,
                                                      ctx=self._ctx)
                            k_steps = k_steps_best
                            logging.debug("k_step is back at %d", k_steps_best)
                            # print the elapsed time
                            t_delta = time() - t_s_steps
                            print(" - %.ds" % t_delta)
                            t_s_steps = time()
                        else:
                            # update the val_loss_value to compare with using spike recovery
                            old_val_loss = val_metric_values["loss"]
                            # log the metric values to tensorboard
                            self._log_metrics(train_metric_values,
                                              global_step=k_steps,
                                              prefix="train_")
                            self._log_metrics(val_metric_values,
                                              global_step=k_steps,
                                              prefix="val_")

                            if self.tc.export_grad_histograms:
                                grads = []
                                # logging the gradients of parameters for checking convergence
                                for _, name in enumerate(self._param_names):
                                    if "bn" not in name and "batch" not in name and name != "policy_flat_plane_idx":
                                        grads.append(self._params[name].grad())
                                        self.sum_writer.add_histogram(
                                            tag=name,
                                            values=grads[-1],
                                            global_step=k_steps,
                                            bins=20)

                            # check if a new checkpoint shall be created
                            if val_loss_best is None or val_metric_values[
                                    "loss"] < val_loss_best:
                                # update val_loss_best
                                val_loss_best = val_metric_values["loss"]
                                val_p_acc_best = val_metric_values[
                                    "policy_acc"]
                                val_metric_values_best = val_metric_values
                                k_steps_best = k_steps

                                if self.tc.export_weights:
                                    prefix = self.tc.export_dir + "weights/model-%.5f-%.3f" \
                                             % (val_loss_best, val_p_acc_best)
                                    # the export function saves both the architecture and the weights
                                    self._net.export(prefix,
                                                     epoch=k_steps_best)
                                    print()
                                    logging.info(
                                        "Saved checkpoint to %s-%04d.params",
                                        prefix, k_steps_best)

                                patience_cnt = 0  # reset the patience counter
                            # print the elapsed time
                            t_delta = time() - t_s_steps
                            print(" - %.ds" % t_delta)
                            t_s_steps = time()

                            # log the samples per second metric to tensorboard
                            self.sum_writer.add_scalar(
                                tag="samples_per_second",
                                value={
                                    "hybrid_sync":
                                    data.shape[0] * self.tc.batch_steps /
                                    t_delta
                                },
                                global_step=k_steps,
                            )

                            # log the current learning rate
                            self.sum_writer.add_scalar(
                                tag="lr",
                                value=self.to.lr_schedule(cur_it),
                                global_step=k_steps)
                            # log the current momentum value
                            self.sum_writer.add_scalar(
                                tag="momentum",
                                value=self.to.momentum_schedule(cur_it),
                                global_step=k_steps)

                            if cur_it >= self.tc.total_it:

                                val_loss = val_metric_values["loss"]
                                val_p_acc = val_metric_values["policy_acc"]
                                logging.debug(
                                    "The number of given iterations has been reached"
                                )
                                # finally stop training because the number of lr drops has been achieved
                                print()
                                print("Elapsed time for training(hh:mm:ss): " +
                                      str(
                                          datetime.timedelta(
                                              seconds=round(time() - t_s))))

                                if self.tc.log_metrics_to_tensorboard:
                                    self.sum_writer.close()

                                return return_metrics_and_stop_training(
                                    k_steps, val_metric_values, k_steps_best,
                                    val_metric_values_best)
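`self.to.lr_schedule` and `self.to.momentum_schedule` are plain callables that map the iteration counter to a value. A minimal sketch of compatible schedules, assuming a simple linear ramp; the schedules actually used by the project may be more elaborate (e.g. one-cycle):

def make_linear_schedule(start, end, total_it):
    # hypothetical: interpolate linearly from `start` to `end` over `total_it`
    def schedule(cur_it):
        frac = min(cur_it / float(total_it), 1.0)
        return start + (end - start) * frac
    return schedule

lr_schedule = make_linear_schedule(0.1, 0.001, total_it=100000)
momentum_schedule = make_linear_schedule(0.8, 0.95, total_it=100000)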
Ejemplo n.º 27
0
def mytrain(net,num_classes,train_data,valid_data,ctx,start_epoch, end_epoch, \
            arm_cls_loss=arm_cls_loss,cls_loss=cls_loss,box_loss=box_loss,trainer=None):
    if trainer is None:
        # trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01,'momentum':0.9, 'wd':50.0})
        trainer = gluon.Trainer(net.collect_params(), 'adam', {
            'learning_rate': 0.001,
            'clip_gradient': 2.0
        })
        # trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': 0.003})
    box_metric = metric.MAE()

    ## add visualization
    # collect parameter names for logging the gradients of parameters in each epoch
    params = net.collect_params()
    # param_names = params.keys()
    # define a summary writer that logs data and flushes to the file every 5 seconds
    sw = SummaryWriter(logdir='./logs', flush_secs=5)
    global_step = 0

    for e in range(start_epoch, end_epoch):
        # print(e)
        train_data.reset()
        valid_data.reset()
        box_metric.reset()
        tic = time.time()
        _loss = [0, 0]
        arm_loss = [0, 0]
        # if e == 6 or e == 100:
        #     trainer.set_learning_rate(trainer.learning_rate * 0.2)

        outs, labels = None, None
        for i, batch in enumerate(train_data):
            # print('----- batch {} start ----'.format(i))
            data = batch.data[0].as_in_context(ctx)
            label = batch.label[0].as_in_context(ctx)
            # print('label shape: ',label.shape)
            with autograd.record():
                # 1. generate results according to extract network
                ssd_layers = net(data)
                arm_loc_preds, arm_cls_preds, arm_anchor_boxes, odm_loc_preds, odm_cls_preds = multibox_layer(ssd_layers,\
                                                                            num_classes,sizes,ratios,normalizations)
                # arm_loc_preds, arm_cls_preds, arm_anchor_boxes, odm_loc_preds, odm_cls_preds = net(data)
                # print('---------1111-----------')
                # 2. ARM predict
                ## 2.1  modify label as [-1,0,..]
                label_arm = nd.Custom(label, op_type='modify_label')
                arm_tmp = MultiBoxTarget(arm_anchor_boxes,label_arm,arm_cls_preds,overlap_threshold=.5,\
                                         negative_mining_ratio=3,negative_mining_thresh=.5)
                arm_loc_target = arm_tmp[0]  # box offset
                arm_loc_target_mask = arm_tmp[1]  # box mask (only 0,1)
                arm_cls_target = arm_tmp[2]  # every anchor's class index
                # print(sum(arm_cls_target[0]))
                # print('---------2222-----------')

                # 3. ODM predict
                ## 3.1 refine anchor generator originate in ARM
                odm_anchor_boxes = refine_anchor_generator(
                    arm_anchor_boxes,
                    arm_loc_preds)  #(batch,h*w*num_anchors[:layers],4)
                # ### debug backward err
                # odm_anchor_boxes = arm_anchor_boxes
                odm_anchor_boxes_bs = nd.split(
                    data=odm_anchor_boxes, axis=0,
                    num_outputs=label.shape[0])  # list
                # print('---3 : odm_anchor_boxes_bs shape : {}'.format(odm_anchor_boxes_bs[0].shape))
                # print('---------3333-----------')
                ## 3.2 compute the targets for all data in the current batch (for multi-GPU use)

                odm_loc_target = []
                odm_loc_target_mask = []
                odm_cls_target = []
                label_bs = nd.split(data=label,
                                    axis=0,
                                    num_outputs=label.shape[0])
                odm_cls_preds_bs = nd.split(data=odm_cls_preds,
                                            axis=0,
                                            num_outputs=label.shape[0])
                # print('---4 : odm_cls_preds_bs shape: {}'.format(odm_cls_preds_bs[0].shape))
                # print('---4 : label_bs shape: {}'.format(label_bs[0].shape))

                for j in range(label.shape[0]):
                    if label.shape[0] == 1:
                        odm_tmp = MultiBoxTarget(odm_anchor_boxes_bs[j].expand_dims(axis=0),label_bs[j].expand_dims(axis=0),\
                                            odm_cls_preds_bs[j].expand_dims(axis=0),overlap_threshold=.5,negative_mining_ratio=2,negative_mining_thresh=.5)
                    ## batch size greater than one
                    else:
                        odm_tmp = MultiBoxTarget(odm_anchor_boxes_bs[j],label_bs[j],\
                                            odm_cls_preds_bs[j],overlap_threshold=.5,negative_mining_ratio=3,negative_mining_thresh=.5)
                    odm_loc_target.append(odm_tmp[0])
                    odm_loc_target_mask.append(odm_tmp[1])
                    odm_cls_target.append(odm_tmp[2])
                ### concat; each image was handled separately above because odm covers the whole batch, so it had to be split first
                odm_loc_target = nd.concat(*odm_loc_target, dim=0)
                odm_loc_target_mask = nd.concat(*odm_loc_target_mask, dim=0)
                odm_cls_target = nd.concat(*odm_cls_target, dim=0)

                # 4. negative filtering
                group = nd.Custom(arm_cls_preds,
                                  odm_cls_target,
                                  odm_loc_target_mask,
                                  op_type='negative_filtering')
                odm_cls_target = group[0]  # odm_cls filtered by the ARM cls scores
                odm_loc_target_mask = group[1]  # mask is 0 for the filtered-out anchors
                # print('---------4444-----------')
                # 5. calc loss
                # TODO:add 1/N_arm, 1/N_odm (num of positive anchors)
                # arm_cls_loss = gluon.loss.SoftmaxCrossEntropyLoss()
                arm_loss_cls = arm_cls_loss(arm_cls_preds.transpose((0, 2, 1)),
                                            arm_cls_target)
                arm_loss_loc = box_loss(arm_loc_preds, arm_loc_target,
                                        arm_loc_target_mask)
                # print('55555 loss->  arm_loss_cls : {} arm_loss_loc {}'.format(arm_loss_cls.shape,arm_loss_loc.shape))
                # print('arm_loss_cls loss : {}'.format(arm_loss_cls))
                # odm_cls_prob = nd.softmax(odm_cls_preds,axis=2)
                odm_loss_cls = cls_loss(odm_cls_preds.transpose((0, 2, 1)),
                                        odm_cls_target)
                odm_loss_loc = box_loss(odm_loc_preds, odm_loc_target,
                                        odm_loc_target_mask)
                # print('66666 loss->  odm_loss_cls : {} odm_loss_loc {}'.format(odm_loss_cls.shape,odm_loss_loc.shape))
                # print('odm_loss_cls loss :{} '.format(odm_loss_cls))
                # print('odm_loss_loc loss :{} '.format(odm_loss_loc))
                # print('N_arm: {} ; N_odm: {} '.format(nd.sum(arm_loc_target_mask,axis=1)/4.0,nd.sum(odm_loc_target_mask,axis=1)/4.0))
                # loss = arm_loss_cls+arm_loss_loc+odm_loss_cls+odm_loss_loc
                loss = 1/(nd.sum(arm_loc_target_mask,axis=1)/4.0) *(arm_loss_cls+arm_loss_loc) + \
                        1/(nd.sum(odm_loc_target_mask,axis=1)/4.0)*(odm_loss_cls+odm_loss_loc)

            sw.add_scalar(tag='loss',
                          value=loss.mean().asscalar(),
                          global_step=global_step)
            global_step += 1
            loss.backward(retain_graph=False)
            # autograd.backward(loss)
            # print(net.collect_params().get('conv4_3_weight').data())
            # print(net.collect_params().get('vgg0_conv9_weight').grad())
            ### test each gradient separately
            # arm_loss_cls.backward(retain_graph=False)
            # arm_loss_loc.backward(retain_graph=False)
            # odm_loss_cls.backward(retain_graph=False)
            # odm_loss_loc.backward(retain_graph=False)

            trainer.step(data.shape[0])
            _loss[0] += nd.mean(odm_loss_cls).asscalar()
            _loss[1] += nd.mean(odm_loss_loc).asscalar()
            arm_loss[0] += nd.mean(arm_loss_cls).asscalar()
            arm_loss[1] += nd.mean(arm_loss_loc).asscalar()
            # print(arm_loss)
            arm_cls_prob = nd.SoftmaxActivation(arm_cls_preds, mode='channel')
            odm_cls_prob = nd.SoftmaxActivation(odm_cls_preds, mode='channel')
            out = MultiBoxDetection(odm_cls_prob,odm_loc_preds,odm_anchor_boxes,\
                                        force_suppress=True,clip=False,nms_threshold=.5,nms_topk=400)
            # print('out shape: {}'.format(out.shape))
            if outs is None:
                outs = out
                labels = label
            else:
                outs = nd.concat(outs, out, dim=0)
                labels = nd.concat(labels, label, dim=0)
            box_metric.update([odm_loc_target],
                              [odm_loc_preds * odm_loc_target_mask])
        print('-------{} epoch end ------'.format(e))
        train_AP = evaluate_MAP(outs, labels)
        valid_AP, val_box_metric = evaluate_acc(net, valid_data, ctx)
        info["train_ap"].append(train_AP)
        info["valid_ap"].append(valid_AP)
        info["loss"].append(_loss)
        print('odm loss: ', _loss)
        print('arm loss: ', arm_loss)
        if e == 0:
            sw.add_graph(net)
        # grads = [i.grad() for i in net.collect_params().values()]
        # grads_4_3 = net.collect_params().get('vgg0_conv9_weight').grad()
        # sw.add_histogram(tag ='vgg0_conv9_weight',values=grads_4_3,global_step=e, bins=1000 )
        grads_4_2 = net.collect_params().get('vgg0_conv5_weight').grad()
        sw.add_histogram(tag='vgg0_conv5_weight',
                         values=grads_4_2,
                         global_step=e,
                         bins=1000)
        # assert len(grads) == len(param_names)
        # logging the gradients of parameters for checking convergence
        # for i, name in enumerate(param_names):
        #     sw.add_histogram(tag=name, values=grads[i], global_step=e, bins=1000)

        # net.export('./Model/RefineDet_MeterDetect') # net
        if (e + 1) % 5 == 0:
            print(
                "epoch: %d time: %.2f cls loss: %.4f,reg loss: %.4f lr: %.5f" %
                (e, time.time() - tic, _loss[0], _loss[1],
                 trainer.learning_rate))
            print("train mae: %.4f AP: %.4f" % (box_metric.get()[1], train_AP))
            print("valid mae: %.4f AP: %.4f" %
                  (val_box_metric.get()[1], valid_AP))
        sw.add_scalar(tag='train_AP', value=train_AP, global_step=e)
        sw.add_scalar(tag='valid_AP', value=valid_AP, global_step=e)
    sw.close()
    if True:
        info["loss"] = np.array(info["loss"])
        info["cls_loss"] = info["loss"][:, 0]
        info["box_loss"] = info["loss"][:, 1]

        plt.figure(figsize=(12, 4))
        plt.subplot(121)
        plot("train_ap")
        plot("valid_ap")
        plt.legend(loc="upper right")
        plt.subplot(122)
        plot("cls_loss")
        plot("box_loss")
        plt.legend(loc="upper right")
        plt.savefig('loss_curve.png')
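The `plot(...)` calls above assume a small helper that draws one curve from the global `info` dict with a legend label. A sketch of such a helper, assuming matplotlib is imported as `plt`:

def plot(key):
    # hypothetical helper: draw info[key] over epochs, labelled for the legend
    plt.plot(info[key], label=key)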
Ejemplo n.º 28
0
def train():
    # load_data
    batch_size = args.batch_size * max(args.num_gpus, 1)
    train_set = gluon.data.vision.CIFAR10(train=True,
                                          transform=transform_train)
    train_data = DataLoader(train_set,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=args.num_workers,
                            last_batch='discard')
    val_set = gluon.data.vision.CIFAR10(train=False, transform=transform_val)
    val_data = DataLoader(val_set,
                          batch_size=batch_size,
                          shuffle=False,
                          num_workers=args.num_workers)

    # set the network and trainer
    ctx = [mx.gpu(i)
           for i in range(args.num_gpus)] if args.num_gpus > 0 else [mx.cpu()]
    net = get_attention_cifar(10, num_layers=args.num_layers)
    net.initialize(init=mx.initializer.MSRAPrelu(), ctx=ctx)
    net.hybridize()

    trainer = gluon.Trainer(net.collect_params(), 'sgd', {
        'learning_rate': args.lr,
        'momentum': args.momentum,
        'wd': args.wd
    })
    cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss(
        sparse_label=not use_mix_up)
    train_metric = mtc.Accuracy() if not use_mix_up else mx.metric.RMSE()

    # set log output
    train_mode = 'MixUP' if use_mix_up else 'Vanilla'
    logger = logging.getLogger('TRAIN')
    logger.setLevel("INFO")
    logger.addHandler(logging.StreamHandler())
    logger.addHandler(
        logging.FileHandler(
            os.path.join(
                args.log_dir, 'text/cifar10_attention%d_%s_%s.log' %
                (args.num_layers, train_mode,
                 datetime.strftime(datetime.now(), '%Y%m%d%H%M')))))
    sw = SummaryWriter(logdir=os.path.join(
        args.log_dir, 'board/cifar10_attention%d_%s_%s' %
        (args.num_layers, train_mode,
         datetime.strftime(datetime.now(), '%Y%m%d%H%M'))),
                       verbose=False)

    # record the training hyper parameters
    logger.info(args)
    lr_counter = 0
    lr_steps = [int(s) for s in args.lr_steps.strip().split(',')]
    num_batch = len(train_data)
    epochs = args.epochs + 1
    alpha = args.alpha
    max_accuracy = 0.9

    for epoch in range(epochs):
        if epoch == lr_steps[lr_counter]:
            trainer.set_learning_rate(trainer.learning_rate * 0.1)
            if lr_counter + 1 < len(lr_steps):
                lr_counter += 1
        train_loss = 0
        train_metric.reset()
        tic = time.time()
        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0,
                                              even_split=False)
            labels = gluon.utils.split_and_load(batch[1],
                                                ctx_list=ctx,
                                                batch_axis=0,
                                                even_split=False)

            if use_mix_up and epoch < epochs - 20:
                lam = np.random.beta(alpha, alpha)
                data = [lam * X + (1 - lam) * X[::-1] for X in data]
                labels = [lam * Y + (1 - lam) * Y[::-1] for Y in labels]

            with ag.record():
                outputs = [net(X) for X in data]
                losses = [
                    cross_entropy(yhat, y) for yhat, y in zip(outputs, labels)
                ]

            for l in losses:
                ag.backward(l)

            trainer.step(batch_size)
            train_metric.update(labels, outputs)

            train_loss += sum([l.mean().asscalar()
                               for l in losses]) / len(losses)

        _, train_acc = train_metric.get()
        train_loss /= num_batch
        val_acc, val_loss = validate(net, val_data, ctx)

        sw.add_scalar("AttentionNet/Loss", {
            'train': train_loss,
            'val': val_loss
        }, epoch)
        sw.add_scalar("AttentionNet/Metric", {
            'train': train_acc,
            'val': val_acc
        }, epoch)
        logger.info('[Epoch %d] train metric: %.6f, train loss: %.6f | '
                    'val accuracy: %.6f, val loss: %.6f, time: %.1f' %
                    (epoch, train_acc, train_loss, val_acc, val_loss,
                     time.time() - tic))

        if (epoch % args.save_period) == 0 and epoch != 0:
            net.save_parameters(
                "./models/attention%d-cifar10-epoch-%d-%s.params" %
                (args.num_layers, epoch, train_mode))

        if val_acc > max_accuracy:
            net.save_parameters(
                "./models/best-%f-attention%d-cifar10-epoch-%d-%s.params" %
                (val_acc, args.num_layers, epoch, train_mode))
            max_accuracy = val_acc

    sw.close()
    logger.info("Train End.")
Ejemplo n.º 29
0
def main():
    opt = parse_args()

    makedirs(opt.save_dir)

    filehandler = logging.FileHandler(
        os.path.join(opt.save_dir, opt.logging_file))
    streamhandler = logging.StreamHandler()
    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.addHandler(filehandler)
    logger.addHandler(streamhandler)
    logger.info(opt)

    sw = SummaryWriter(logdir=opt.save_dir, flush_secs=5, verbose=False)

    if opt.kvstore is not None:
        kv = mx.kvstore.create(opt.kvstore)
        logger.info(
            'Distributed training with %d workers and current rank is %d' %
            (kv.num_workers, kv.rank))
    if opt.use_amp:
        amp.init()

    batch_size = opt.batch_size
    classes = opt.num_classes

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    logger.info('Total batch size is set to %d on %d GPUs' %
                (batch_size, num_gpus))
    context = [mx.gpu(i)
               for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_period = opt.lr_decay_period
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(lr_decay_period, opt.num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    lr_decay_epoch = [e - opt.warmup_epochs for e in lr_decay_epoch]

    optimizer = 'sgd'
    if opt.clip_grad > 0:
        optimizer_params = {
            'learning_rate': opt.lr,
            'wd': opt.wd,
            'momentum': opt.momentum,
            'clip_gradient': opt.clip_grad
        }
    else:
        optimizer_params = {
            'learning_rate': opt.lr,
            'wd': opt.wd,
            'momentum': opt.momentum
        }

    if opt.dtype != 'float32':
        optimizer_params['multi_precision'] = True

    model_name = opt.model
    net = get_model(name=model_name,
                    nclass=classes,
                    pretrained=opt.use_pretrained,
                    use_tsn=opt.use_tsn,
                    num_segments=opt.num_segments,
                    partial_bn=opt.partial_bn)
    net.cast(opt.dtype)
    net.collect_params().reset_ctx(context)
    logger.info(net)

    if opt.resume_params != '':
        net.load_parameters(opt.resume_params, ctx=context)

    if opt.kvstore is not None:
        train_data, val_data, batch_fn = get_data_loader(
            opt, batch_size, num_workers, logger, kv)
    else:
        train_data, val_data, batch_fn = get_data_loader(
            opt, batch_size, num_workers, logger)

    num_batches = len(train_data)
    lr_scheduler = LRSequential([
        LRScheduler('linear',
                    base_lr=0,
                    target_lr=opt.lr,
                    nepochs=opt.warmup_epochs,
                    iters_per_epoch=num_batches),
        LRScheduler(opt.lr_mode,
                    base_lr=opt.lr,
                    target_lr=0,
                    nepochs=opt.num_epochs - opt.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=lr_decay,
                    power=2)
    ])
    optimizer_params['lr_scheduler'] = lr_scheduler

    train_metric = mx.metric.Accuracy()
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)

    def test(ctx, val_data, kvstore=None):
        acc_top1.reset()
        acc_top5.reset()
        L = gluon.loss.SoftmaxCrossEntropyLoss()
        num_test_iter = len(val_data)
        val_loss_epoch = 0
        for i, batch in enumerate(val_data):
            data, label = batch_fn(batch, ctx)
            outputs = []
            for _, X in enumerate(data):
                X = X.reshape((-1, ) + X.shape[2:])
                pred = net(X.astype(opt.dtype, copy=False))
                outputs.append(pred)

            loss = [
                L(yhat, y.astype(opt.dtype, copy=False))
                for yhat, y in zip(outputs, label)
            ]

            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)

            val_loss_epoch += sum([l.mean().asscalar()
                                   for l in loss]) / len(loss)

            if opt.log_interval and not (i + 1) % opt.log_interval:
                logger.info('Batch [%04d]/[%04d]: evaluated' %
                            (i, num_test_iter))

        _, top1 = acc_top1.get()
        _, top5 = acc_top5.get()
        val_loss = val_loss_epoch / num_test_iter

        if kvstore is not None:
            top1_nd = nd.zeros(1)
            top5_nd = nd.zeros(1)
            val_loss_nd = nd.zeros(1)
            kvstore.push(111111, nd.array(np.array([top1])))
            kvstore.pull(111111, out=top1_nd)
            kvstore.push(555555, nd.array(np.array([top5])))
            kvstore.pull(555555, out=top5_nd)
            kvstore.push(999999, nd.array(np.array([val_loss])))
            kvstore.pull(999999, out=val_loss_nd)
            top1 = top1_nd.asnumpy() / kvstore.num_workers
            top5 = top5_nd.asnumpy() / kvstore.num_workers
            val_loss = val_loss_nd.asnumpy() / kvstore.num_workers

        return (top1, top5, val_loss)

    def train(ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]

        if opt.no_wd:
            for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
                v.wd_mult = 0.0

        if opt.partial_bn:
            train_patterns = None
            if 'inceptionv3' in opt.model:
                train_patterns = '.*weight|.*bias|inception30_batchnorm0_gamma|inception30_batchnorm0_beta|inception30_batchnorm0_running_mean|inception30_batchnorm0_running_var'
            else:
                logger.info(
                    'Current model does not support partial batch normalization.'
                )

            if opt.kvstore is not None:
                trainer = gluon.Trainer(net.collect_params(train_patterns),
                                        optimizer,
                                        optimizer_params,
                                        kvstore=kv,
                                        update_on_kvstore=False)
            else:
                trainer = gluon.Trainer(net.collect_params(train_patterns),
                                        optimizer,
                                        optimizer_params,
                                        update_on_kvstore=False)
        else:
            if opt.kvstore is not None:
                trainer = gluon.Trainer(net.collect_params(),
                                        optimizer,
                                        optimizer_params,
                                        kvstore=kv,
                                        update_on_kvstore=False)
            else:
                trainer = gluon.Trainer(net.collect_params(),
                                        optimizer,
                                        optimizer_params,
                                        update_on_kvstore=False)

        if opt.accumulate > 1:
            params = [
                p for p in net.collect_params().values()
                if p.grad_req != 'null'
            ]
            for p in params:
                p.grad_req = 'add'

        if opt.resume_states != '':
            trainer.load_states(opt.resume_states)

        if opt.use_amp:
            amp.init_trainer(trainer)

        L = gluon.loss.SoftmaxCrossEntropyLoss()

        best_val_score = 0
        lr_decay_count = 0

        for epoch in range(opt.resume_epoch, opt.num_epochs):
            tic = time.time()
            train_metric.reset()
            btic = time.time()
            num_train_iter = len(train_data)
            train_loss_epoch = 0
            train_loss_iter = 0

            for i, batch in enumerate(train_data):
                data, label = batch_fn(batch, ctx)

                with ag.record():
                    outputs = []
                    for _, X in enumerate(data):
                        X = X.reshape((-1, ) + X.shape[2:])
                        pred = net(X.astype(opt.dtype, copy=False))
                        outputs.append(pred)
                    loss = [
                        L(yhat, y.astype(opt.dtype, copy=False))
                        for yhat, y in zip(outputs, label)
                    ]

                    if opt.use_amp:
                        with amp.scale_loss(loss, trainer) as scaled_loss:
                            ag.backward(scaled_loss)
                    else:
                        ag.backward(loss)

                if opt.accumulate > 1 and (i + 1) % opt.accumulate == 0:
                    if opt.kvstore is not None:
                        trainer.step(batch_size * kv.num_workers *
                                     opt.accumulate)
                    else:
                        trainer.step(batch_size * opt.accumulate)
                        net.collect_params().zero_grad()
                else:
                    if opt.kvstore is not None:
                        trainer.step(batch_size * kv.num_workers)
                    else:
                        trainer.step(batch_size)

                train_metric.update(label, outputs)
                train_loss_iter = sum([l.mean().asscalar()
                                       for l in loss]) / len(loss)
                train_loss_epoch += train_loss_iter

                train_metric_name, train_metric_score = train_metric.get()
                sw.add_scalar(tag='train_acc_top1_iter',
                              value=train_metric_score * 100,
                              global_step=epoch * num_train_iter + i)
                sw.add_scalar(tag='train_loss_iter',
                              value=train_loss_iter,
                              global_step=epoch * num_train_iter + i)
                sw.add_scalar(tag='learning_rate_iter',
                              value=trainer.learning_rate,
                              global_step=epoch * num_train_iter + i)

                if opt.log_interval and not (i + 1) % opt.log_interval:
                    logger.info(
                        'Epoch[%03d] Batch [%04d]/[%04d]\tSpeed: %f samples/sec\t %s=%f\t loss=%f\t lr=%f'
                        % (epoch, i, num_train_iter,
                           batch_size * opt.log_interval /
                           (time.time() - btic), train_metric_name,
                           train_metric_score * 100, train_loss_epoch /
                           (i + 1), trainer.learning_rate))
                    btic = time.time()

            train_metric_name, train_metric_score = train_metric.get()
            throughput = int(batch_size * (i + 1) / (time.time() - tic))
            mx.ndarray.waitall()

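            # one-time kvstore key registration (presumably the keys test()
            # uses to all-reduce validation stats across workers)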
            if opt.kvstore is not None and epoch == opt.resume_epoch:
                kv.init(111111, nd.zeros(1))
                kv.init(555555, nd.zeros(1))
                kv.init(999999, nd.zeros(1))

            if opt.kvstore is not None:
                acc_top1_val, acc_top5_val, loss_val = test(ctx, val_data, kv)
            else:
                acc_top1_val, acc_top5_val, loss_val = test(ctx, val_data)

            logger.info('[Epoch %03d] training: %s=%f\t loss=%f' %
                        (epoch, train_metric_name, train_metric_score * 100,
                         train_loss_epoch / num_train_iter))
            logger.info('[Epoch %03d] speed: %d samples/sec\ttime cost: %f' %
                        (epoch, throughput, time.time() - tic))
            logger.info(
                '[Epoch %03d] validation: acc-top1=%f acc-top5=%f loss=%f' %
                (epoch, acc_top1_val * 100, acc_top5_val * 100, loss_val))

            sw.add_scalar(tag='train_loss_epoch',
                          value=train_loss_epoch / num_train_iter,
                          global_step=epoch)
            sw.add_scalar(tag='val_loss_epoch',
                          value=loss_val,
                          global_step=epoch)
            sw.add_scalar(tag='val_acc_top1_epoch',
                          value=acc_top1_val * 100,
                          global_step=epoch)

            if acc_top1_val > best_val_score:
                best_val_score = acc_top1_val
                net.save_parameters('%s/%.4f-%s-%s-%03d-best.params' %
                                    (opt.save_dir, best_val_score, opt.dataset,
                                     model_name, epoch))
                trainer.save_states('%s/%.4f-%s-%s-%03d-best.states' %
                                    (opt.save_dir, best_val_score, opt.dataset,
                                     model_name, epoch))
            else:
                if opt.save_frequency and opt.save_dir and (
                        epoch + 1) % opt.save_frequency == 0:
                    net.save_parameters(
                        '%s/%s-%s-%03d.params' %
                        (opt.save_dir, opt.dataset, model_name, epoch))
                    trainer.save_states(
                        '%s/%s-%s-%03d.states' %
                        (opt.save_dir, opt.dataset, model_name, epoch))

        # save the last model
        net.save_parameters(
            '%s/%s-%s-%03d.params' %
            (opt.save_dir, opt.dataset, model_name, opt.num_epochs - 1))
        trainer.save_states(
            '%s/%s-%s-%03d.states' %
            (opt.save_dir, opt.dataset, model_name, opt.num_epochs - 1))

    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)

    train(context)
    sw.close()
Ejemplo n.º 30
0
def train():
    """training"""
    image_pool = ImagePool(pool_size)
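    # `facc` is assumed to be a binary-accuracy helper defined elsewhere in
    # this script (not shown here)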
    metric = mx.metric.CustomMetric(facc)

    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)

    # define a summary writer that logs data and flushes to the file every 5 seconds
    sw = SummaryWriter(logdir=dir_out_sw, flush_secs=5, verbose=False)
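    # global_step counts discriminator/generator updates across all epochs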
    global_step = 0

    for epoch in range(epochs):
        if epoch == 0:
            netG.hybridize()
            netD.hybridize()
        #     sw.add_graph(netG)
        #     sw.add_graph(netD)

        tic = time.time()
        btic = time.time()
        train_data.reset()
        val_data.reset()
        for local_step, batch in enumerate(train_data):
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ###########################
            tmp = mx.nd.concat(batch.data[0],
                               batch.data[1],
                               batch.data[2],
                               dim=1)
            tmp = augmenter(tmp,
                            patch_size=128,
                            offset=offset,
                            aug_type=1,
                            aug_methods=aug_methods,
                            random_crop=False)
            real_in = tmp[:, :1].as_in_context(ctx)
            real_out = tmp[:, 1:2].as_in_context(ctx)
            m = tmp[:, 2:3].as_in_context(ctx)  # mask

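            # generator output, masked so that only pixels inside the valid
            # region contribute to the losses below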
            fake_out = netG(real_in) * m

            # loss weight based on the mask, applied to the L1 loss:
            # background pixels (mask == 0) are down-weighted to 0.1
            if no_loss_weights:
                loss_weight = m
            else:
                loss_weight = mx.nd.where(m == 0,
                                          0.1 * mx.nd.ones_like(m),
                                          m)

            fake_concat = image_pool.query(nd.concat(real_in, fake_out, dim=1))
            with autograd.record():
                # Train with fake image
                # Use image pooling to utilize history images
                output = netD(fake_concat)
                fake_label = nd.zeros(output.shape, ctx=ctx)
                errD_fake = GAN_loss(output, fake_label)
                metric.update([
                    fake_label,
                ], [
                    output,
                ])

                # Train with real image
                real_concat = nd.concat(real_in, real_out, dim=1)
                output = netD(real_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errD_real = GAN_loss(output, real_label)
                errD = (errD_real + errD_fake) * 0.5
                errD.backward()
                metric.update([
                    real_label,
                ], [
                    output,
                ])

            trainerD.step(batch.data[0].shape[0])

            ############################
            # (2) Update G network: maximize log(D(x, G(x, z))) - lambda1 * L1(y, G(x, z))
            ###########################
            with autograd.record():
                fake_out = netG(real_in)
                fake_concat = nd.concat(real_in, fake_out, dim=1)
                output = netD(fake_concat)
                real_label = nd.ones(output.shape, ctx=ctx)
                errG = GAN_loss(output, real_label) + loss_2nd(
                    real_out, fake_out, loss_weight) * lambda1
                errG.backward()

            trainerG.step(batch.data[0].shape[0])

            sw.add_scalar(tag='loss',
                          value=('d_loss', errD.mean().asscalar()),
                          global_step=global_step)
            sw.add_scalar(tag='loss',
                          value=('g_loss', errG.mean().asscalar()),
                          global_step=global_step)
            global_step += 1

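            # one-time logging at the very first iteration: model graph,
            # reference real images, and an exported symbol file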
            if epoch + local_step == 0:
                sw.add_graph(netG)
                img_in_list, img_out_list, m_val = val_data.next().data
                m_val = m_val.as_in_context(ctx)
                sw.add_image('first_minibatch_train_real', norm3(real_out))
                sw.add_image('first_minibatch_val_real',
                             norm3(img_out_list.as_in_context(ctx)))
                netG.export('%snetG' % dir_out_checkpoints)
            if local_step == 0:
                # Log the first batch of images of each epoch (training)
                sw.add_image('first_minibatch_train_fake',
                             norm3(fake_out * m) * m, epoch)
                sw.add_image(
                    'first_minibatch_val_fake',
                    norm3(netG(img_in_list.as_in_context(ctx)) * m_val) *
                    m_val, epoch)

            if (local_step + 1) % 10 == 0:
                name, acc = metric.get()

                logging.info('speed: {} samples/s'.format(
                    batch_size / (time.time() - btic)))
                logging.info(
                    'discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d'
                    % (nd.mean(errD).asscalar(), nd.mean(errG).asscalar(), acc,
                       local_step, epoch))

            btic = time.time()

        # fetch the epoch-level training accuracy before resetting the metric
        name, acc = metric.get()
        metric.reset()

        sw.add_scalar(tag='binary_training_acc',
                      value=('acc', acc),
                      global_step=epoch)

        fake_val = netG(val_data.data[0][1].as_in_context(ctx))
        loss_val = loss_2nd(val_data.data[1][1].as_in_context(ctx), fake_val,
                            val_data.data[2][1].as_in_context(ctx)) * lambda1
        sw.add_scalar(tag='loss_val',
                      value=('g_loss', loss_val.mean().asscalar()),
                      global_step=epoch)

        if (epoch % check_point_interval == 0) or (epoch == epochs - 1):
            netD.save_parameters('%snetD-%04d' % (dir_out_checkpoints, epoch))
            netG.save_parameters('%snetG-%04d' % (dir_out_checkpoints, epoch))

        logging.info('\nbinary training acc at epoch %d: %s=%f' %
                     (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))

    sw.export_scalars('scalar_dict.json')
    sw.close()