Example #1
def nd_forward_backward_and_profile(op, runs, *args, **kwargs):
    """Helper function to run a given NDArray operator (op) for 'runs' number of times with
    given args and kwargs. Executes both forward and backward pass.

    NOTE: This is a sync call and waits for all the operations execution to complete.

    Parameters
    ----------
    op: Str
        NDArray operator (Function reference) to execute. Example: mx.nd.add
    runs: int
        Number of times to execute the operation
    args:
        Arguments for the NDArray operator (op) being executed.
    kwargs:
        Key value arguments for the NDArray operator (op) being executed.

    Returns
    -------
    any results from NDArray operation execution

    """
    for _ in range(runs):
        with mx.autograd.record():
            if not isinstance(args[0], nd.NDArray):
                res = op(**kwargs)
            else:
                res = op(*args, **kwargs)
        res.backward()
        nd.waitall()
    return res
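A minimal usage sketch (my own illustration, not part of the snippet): the inputs must have gradient buffers attached beforehand, because res.backward() is called inside the loop.

import mxnet as mx
from mxnet import nd

lhs = nd.random.uniform(shape=(1024, 1024))
rhs = nd.random.uniform(shape=(1024, 1024))
lhs.attach_grad()   # backward() needs gradient buffers on the inputs
rhs.attach_grad()

# Run mx.nd.add forward + backward 10 times, synchronously.
result = nd_forward_backward_and_profile(nd.add, 10, lhs, rhs)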
Example #2
def nd_forward_and_profile(op, runs, *args, **kwargs):
    """Helper function to run a given NDArray operator (op) for 'runs' number of times with
    given args and kwargs. Executes ONLY forward pass.

    NOTE: This is a sync call and waits for all the operations execution to complete.

    Parameters
    ----------
    op: Str
        NDArray operator (Function reference) to execute. Example: mx.nd.add
    runs: int
        Number of times to execute the operation
    args:
        Arguments for the NDArray operator (op) being executed.
    kwargs:
        Key value arguments for the NDArray operator (op) being executed.

    Returns
    -------
    any results from NDArray operation execution
    """
    for _ in range(runs):
        res = op(*args, **kwargs)
        nd.waitall()
    return res
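Since this variant is typically driven by the MXNet profiler rather than a timer, a hedged pairing with mx.profiler (standard API; shapes and configuration values are assumed) could look like:

import mxnet as mx
from mxnet import nd, profiler

profiler.set_config(profile_all=True, aggregate_stats=True)

x = nd.random.uniform(shape=(1024, 1024))
y = nd.random.uniform(shape=(1024, 1024))

profiler.set_state('run')
nd_forward_and_profile(nd.dot, 10, x, y)   # forward only, synchronous
profiler.set_state('stop')
print(profiler.dumps())                    # aggregated per-operator timings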
Example #3
def predict():
    img = pre_process_image(
        '/incubator-mxnet/scala-package/examples/scripts/infer/images/dog.jpg')
    # compute the predict probabilities
    import time
    # print img.shape
    data_iter = mx.io.NDArrayIter([img], None, 16)
    start = time.time()

    op = mod.predict(data_iter)
    #mod.forward(Batch([img]))
    nd.waitall()
    #    print (type(op[0]))
    end = time.time()

    #   prob = op[0]
    #    print (op[0])
    #op[0].copyto(prob)

    #prob = prob.asnumpy()
    # print (mod.get_outputs()[0].shape)
    # print len(mod.get_outputs())
    #prob = mod.get_outputs()[0]
    #shape = mod.get_outputs()[0].shape
    print(end - start)

    #prob = np.squeeze(prob)
    #a = np.argsort(prob)[::-1]
    #for i in a[0:5]:
    #   print('probability=%f, class=%s' %(prob[i], labels[i]))
    # print the top-5
    #    prob = np.squeeze(prob)
    #   a = np.argsort(prob)[::-1]
    # print (len(prob))
    return end - start
Example #4
def test_nd_op(ctx=mx.gpu()):
    nd.waitall()
    np_arr = np.zeros((3, 1, 5, 5))
    nd_arr = nd.array(np_arr, ctx)
    print(nd_arr)
    time.sleep(5)
    print('success')
Example #5
    def workflow_inference(self, instream, shape):
        while (True):
            # Capture frame-by-frame
            ret, frame = cap.read()

            frame = cv2.resize(frame, (1280, 720))

            # Our operations on the frame come here
            #gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

            out = self._retina_forward(frame)

            try:
                self.write_queue((frame, out))
            except:
                waitall()
                print('Frame queue full', file=sys.stderr)

            # Display the resulting frame
            #cv2.imshow('frame',gray)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

        # When everything done, release the capture
        cap.release()
        cv2.destroyAllWindows()
Example #6
    def _impl(op, params, graph, **kwargs):
        deps = kwargs['deps']
        name, op_name = op.attr('name'), op.attr('op_name')
        childs, attr = sym_iter(op.get_children()), op.list_attr()

        if op_name == 'null':
            start_time = None
            out = data if is_inputs(op, params) else params[name]
        elif childs is None:
            start_time = time.time()
            out = get_nd_op(op_name)(**attr)
            if gpu_flag:
                nd.waitall()
            end_time = time.time()
        else:
            cinfos = [(c.attr('name'), get_entry_id(c)) for c in childs]
            nd_inputs = [out_cache[n[0]][n[1]] for n in cinfos]
            start_time = time.time()
            out = get_nd_op(op_name)(*nd_inputs, **attr)
            if gpu_flag:
                nd.waitall()
            end_time = time.time()
            for n, _ in cinfos:
                assert n in deps
                deps[n].remove(name)
                if len(deps[n]) == 0:
                    del out_cache[n]
        if start_time is not None:
            if op_name not in times:
                times[op_name] = {}
            times[op_name][name] = end_time - start_time
        out = [out] if len(op) == 1 else out
        out_cache[name] = [o.as_in_context(ctx) for o in out]
Example #7
    def hybrid_forward(self, F, x, *args, **kwargs):
        for layer in self.net:
            tic = time.time()
            x = layer(x)
            nd.waitall()
            time0 = time.time() - tic   # per-layer elapsed time (computed but unused)
            time0 -= 0
        return x
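The same per-layer timing idea works outside a class; a self-contained sketch (illustrative, with an assumed toy network) that walks a small HybridSequential and prints how long each layer takes once nd.waitall() has flushed the async engine:

import time
import mxnet as mx
from mxnet import nd, gluon

net = gluon.nn.HybridSequential()
net.add(gluon.nn.Dense(256, activation='relu'),
        gluon.nn.Dense(10))
net.initialize()

x = nd.random.uniform(shape=(32, 512))
for layer in net:
    tic = time.time()
    x = layer(x)
    nd.waitall()   # block until this layer's work is really done
    print('%s: %.6fs' % (layer.name, time.time() - tic))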
Example #8
def test_model(model, inpt, amount, wait=True):
    tic = time.time()
    for i in range(amount):
        if amount > 1 and i == 1: tic = time.time()
        out = model(inpt)
        if wait: nd.waitall()
    amount = amount - 1 if amount > 1 else amount
    time_use = (time.time() - tic)/amount
    return time_use, out
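A possible call site for test_model (names and shapes are illustrative): the first iteration is dropped from the average because tic is reset at i == 1, which excludes one-off costs such as lazy initialization.

import mxnet as mx
from mxnet import nd, gluon

model = gluon.nn.Dense(128)
model.initialize()
inpt = nd.random.uniform(shape=(64, 256))

avg_seconds, out = test_model(model, inpt, amount=20)
print('average forward time: %.6fs' % avg_seconds)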
Example #9
    def do_a_train_step(self, stp=None):

        # NOTE: wait the weights
        ndarray.waitall()

        if not stp:
            stp = self.batch_size
        self.trainer.step(stp)

        # TODO: zero grad
        mxprms.params_zero_grad(self.net.collect_params())
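mxprms.params_zero_grad is an external helper that is not shown here; one plausible implementation, assuming standard Gluon Parameter objects, is simply:

def params_zero_grad(params):
    """Reset the gradient buffers of a Gluon ParameterDict in place (illustrative)."""
    for param in params.values():
        if param.grad_req != 'null':
            param.zero_grad()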
Example #10
    def workflow_inference(self, instream, shape):
        for source in instream:
            # st = time.perf_counter()

            frame = frombuffer(source, dtype=uint8).reshape(shape)
            out = self._retina_forward(frame)

            try:
                self.write_queue((frame, out))
            except:
                waitall()
                print('Frame queue full', file=sys.stderr)
Example #11
def trainNet(net, trainer, train_data, loss, train_metric, epoch, config,
             logger, ctx):
    if not logger:
        assert False, 'require a logger'

    train_data.reset()  # reset and re-shuffle
    if train_metric:
        train_metric.reset()

    trainloss, n = [0] * len(ctx), 0

    for batch_i, batch in enumerate(train_data):
        data_list = gluon.utils.split_and_load(batch.data[0],
                                               ctx_list=ctx,
                                               batch_axis=0)
        label_list = gluon.utils.split_and_load(batch.label[0],
                                                ctx_list=ctx,
                                                batch_axis=0)

        Ls = []
        output_list = []
        with autograd.record():
            for x, y in zip(data_list, label_list):
                preds = net(x)
                L = loss(preds, y)
                Ls.append(L)
                output_list.append(preds)

            for L in Ls:
                L.backward()

        trainer.step(batch.data[0].shape[0])

        # Number
        n += batch.data[0].shape[0]

        # Loss
        for i in range(len(trainloss)):
            trainloss[i] += Ls[i]

    nd.waitall()
    trainloss = sum([item.sum().asscalar() for item in trainloss])
    logger.info("TRAIN - Epoch:%d LR:%.2e Loss:%.2e" %
                (epoch + 1, trainer.learning_rate, trainloss / n))

    # save model
    if ((epoch + 1) % (config.TRAIN.end_epoch / 5) == 0 or epoch == 0):
        saveModel(net, logger, config, isCKP=True, epoch=epoch + 1)
    if (epoch + 1 == config.TRAIN.end_epoch):
        saveModel(net, logger, config, isCKP=False, epoch=epoch + 1)
Example #12
def nd_forward_and_time(F, runs, *args, **kwargs):
    """Helper function to run a given NDArray operator (F) for 'runs' number of times with
    given args and kwargs. Executes ONLY forward pass.

    NOTE: This is a sync call and waits for all the operations execution to complete.

    :param F: NDArray operator (Function reference) to execute. Example: mx.nd.add
    :param runs: Number of times to execute the operation
    :param args: Arguments for the NDArray operator (F) being executed.
    :param kwargs: Key value arguments for the NDArray operator (F) being executed.
    :return: Tuple(Total execution time in seconds, any results from NDArray operation execution)
    """
    for _ in range(runs):
        F(*args, **kwargs)
        nd.waitall()
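Note that the body above never returns the tuple promised by its docstring, so the timing is presumably taken by an external wrapper in the original benchmark. A hedged sketch of such a wrapper:

import time
from mxnet import nd

def timed(func, *args, **kwargs):
    """Time a synchronous helper such as nd_forward_and_time (illustrative)."""
    start = time.time()
    res = func(*args, **kwargs)
    nd.waitall()   # make sure nothing is still pending before stopping the clock
    return time.time() - start, res

total_seconds, _ = timed(nd_forward_and_time, nd.add, 10,
                         nd.ones((1024, 1024)), nd.ones((1024, 1024)))
print('10 runs of nd.add took %.6fs' % total_seconds)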
Example #13
def block_forward_backward_and_time(*args, block, runs, **kwargs):
    """Helper function to run a given Block (block) for 'runs' number of times with
    given args and kwargs. Executes both forward and backward pass.

    NOTE: This is a sync call and waits for all the operations execution to complete.

    :param block: Gluon block to execute. Example: an instance of gluon.nn.Dense(...)
    :param runs: Number of times to execute the block operation
    :param args: Arguments for the block being executed.
    :param kwargs: Key value arguments for the block being executed.
    :return: Tuple of (Total execution time in seconds, any results from block execution)
    """
    for _ in range(runs):
        with mx.autograd.record():
            res = block.forward(*args, **kwargs)
        res.backward()
        nd.waitall()
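An illustrative call (not from the original source) using a plain gluon.nn.Dense block; its parameters default to grad_req='write', so res.backward() has gradients to compute:

import mxnet as mx
from mxnet import nd, gluon

block = gluon.nn.Dense(64)
block.initialize()
x = nd.random.uniform(shape=(32, 128))

block_forward_backward_and_time(x, block=block, runs=10)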
Example #14
    def _run_forward_backward_benchmark(self, runs, x):
        for _ in range(runs):
            with mx.autograd.record():
                # Forward
                res1 = x + 1
                res2 = res1 + 1
                res3 = res2 + 1
                res4 = nd.Custom(res3,
                                 name="customaddone",
                                 op_type="CustomAddOne")
                res5 = res4 + 1
                res6 = res5 + 1
                res7 = res6 + 1

            # Backward
            res7.backward()
            nd.waitall()
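nd.Custom(..., op_type="CustomAddOne") assumes a custom operator registered elsewhere in the benchmark. A minimal registration that would satisfy this call, written against MXNet's standard mx.operator.CustomOp API (my sketch, not the original definition):

import mxnet as mx

class CustomAddOne(mx.operator.CustomOp):
    def forward(self, is_train, req, in_data, out_data, aux):
        # out = in + 1
        self.assign(out_data[0], req[0], in_data[0] + 1)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        # d(out)/d(in) = 1, so pass the incoming gradient straight through
        self.assign(in_grad[0], req[0], out_grad[0])

@mx.operator.register("CustomAddOne")
class CustomAddOneProp(mx.operator.CustomOpProp):
    def __init__(self):
        super(CustomAddOneProp, self).__init__(need_top_grad=True)

    def list_arguments(self):
        return ['data']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        return [in_shape[0]], [in_shape[0]], []

    def create_operator(self, ctx, shapes, dtypes):
        return CustomAddOne()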
Example #15
def block_forward_and_time(*args, block, runs, **kwargs):
    """Helper function to run a given Block (block) for 'runs' number of times with
    given args and kwargs. Executes forward pass only.

    NOTE: This is a sync call and waits for all the operations execution to complete.

    :param block: Gluon block to execute. Example: an instance of gluon.nn.Dense(...)
    :param runs: Number of times to execute the block operation
    :param args: Arguments for the block being executed.
    :param kwargs: Key value arguments for the block being executed.
    :return: Tuple of (Total execution time in seconds, any results from block execution)
    """

    for _ in range(runs):
        # Imperative Mode. This is block forward function
        block.hybrid_forward(F=nd, *args, **kwargs)
        nd.waitall()
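Because the helper calls hybrid_forward(F=nd, ...) directly, any block with learnable parameters would also need those parameters passed through kwargs; the simplest illustration is therefore a parameter-free block (my example, not from the source):

import mxnet as mx
from mxnet import nd, gluon

relu = gluon.nn.Activation('relu')
x = nd.random.uniform(-1, 1, shape=(32, 128))

block_forward_and_time(x, block=relu, runs=10)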
Example #16
def nd_forward_backward_and_profile(op, runs, **kwargs):
    """Helper function to run a given NDArray operator (op) for 'runs' number of times with
    given args and kwargs. Executes both forward and backward pass.

    NOTE: This is a sync call and waits for all the operations execution to complete.

    Parameters
    ----------
    op: Str
        NDArray operator (Function reference) to execute. Example: mx.nd.add
    runs: int
        Number of times to execute the operation
    kwargs:
        Key value arguments for the NDArray operator (op) being executed.

    Returns
    -------
    any results from NDArray operation execution

    """
    for _ in range(runs):
        with mx.autograd.record():
            args = []
            # need to create a new dictionary because can't update dict while iterating
            kwargs_new = dict()
            for key in kwargs:
                # separate positional args from key-worded args
                if key.startswith("args"):
                    args.append(kwargs[key])
                else:
                    kwargs_new[key] = kwargs[key]
            # check for positional args
            if len(args):
                res = op(*args, **kwargs_new)
            else:
                res = op(**kwargs_new)
        res.backward()
        nd.waitall()
    return res
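Here positional inputs are smuggled in through keyword arguments whose names start with "args"; an invocation could look like this (the key names follow that convention, everything else is assumed):

from mxnet import nd

lhs = nd.ones((1024, 1024))
rhs = nd.ones((1024, 1024))
lhs.attach_grad()
rhs.attach_grad()

# "args0"/"args1" are stripped off and passed positionally (in insertion order);
# any remaining keys stay keyword arguments for the operator.
res = nd_forward_backward_and_profile(nd.add, 10, args0=lhs, args1=rhs)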
Example #17
def validNet(net, valid_data, loss, eval_metric, epoch, config, logger, ctx):
    if not logger:
        assert False, 'require a logger'

    valid_data.reset()
    if eval_metric:
        eval_metric.reset()

    validloss, n = [0] * len(ctx), 0
    for batch_i, batch in enumerate(valid_data):
        data_list = gluon.utils.split_and_load(batch.data[0],
                                               ctx_list=ctx,
                                               batch_axis=0)
        label_list = gluon.utils.split_and_load(batch.label[0],
                                                ctx_list=ctx,
                                                batch_axis=0)

        Ls = []
        output_list = []
        for x, y in zip(data_list, label_list):
            preds = net(x)
            L = loss(preds, y)
            output_list.append(preds)
            Ls.append(L)

        if config.TRAIN.UseMetric:
            for lb, pd in zip(label_list, output_list):
                eval_metric.update(lb, pd)

        for i in range(len(validloss)):
            validloss[i] += Ls[i]
        n += batch.data[0].shape[0]

    nd.waitall()
    validloss = sum([item.sum().asscalar() for item in validloss])
    MPJPE = eval_metric.get()[-1].sum(axis=0) / 17
    logger.info("VALID - Epoch:%d Loss:%.3e MPJPE:%.1f" %
                (epoch + 1, validloss / n, MPJPE))
Example #18
def _chris_update_params_on_kvstore(param_arrays, grad_arrays, kvstore,
                                    param_names):
    """Perform update of param_arrays from grad_arrays on kvstore."""

    for index, pair in enumerate(zip(param_arrays, grad_arrays)):
        arg_list, grad_list = pair
        if grad_list[0] is None:
            continue
        name = param_names[index]
        # push gradient, priority is negative index
        kvstore.push(name, grad_list, priority=-index)
    if os.getenv("GLOBAL_BARRIER", 0) == 1:
        ndarray.waitall()
    # if os.getenv('PULL_SLEEP_TIME') is not None:
    #     delay = float(os.getenv('PULL_SLEEP_TIME'))
    #     time.sleep(delay)
    # self.logger.info("before pull in  _chris_update_params_on_kvstore, time is:",time.time())
    for index, pair in enumerate(zip(param_arrays, grad_arrays)):
        arg_list, grad_list = pair
        if grad_list[0] is None:
            continue
        name = param_names[index]
        # pull back the weights
        kvstore.pull(name, arg_list, priority=-index)
Example #19
def train(
    args,
    model,
    train_sampler,
    valid_samplers=None,
    rank=0,
    rel_parts=None,
    barrier=None,
):
    assert args.num_proc <= 1, "MXNet KGE does not support multi-process now"
    assert (args.rel_part == False
            ), "No need for relation partition in single process for MXNet KGE"
    logs = []

    for arg in vars(args):
        logging.info("{:20}:{}".format(arg, getattr(args, arg)))

    if len(args.gpu) > 0:
        gpu_id = (args.gpu[rank % len(args.gpu)]
                  if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0])
    else:
        gpu_id = -1

    if args.strict_rel_part:
        model.prepare_relation(mx.gpu(gpu_id))

    if mxprofiler:
        from mxnet import profiler

        profiler.set_config(
            profile_all=True,
            aggregate_stats=True,
            continuous_dump=True,
            filename="profile_output.json",
        )
    start = time.time()
    for step in range(0, args.max_step):
        pos_g, neg_g = next(train_sampler)
        args.step = step
        if step == 1 and mxprofiler:
            profiler.set_state("run")
        with mx.autograd.record():
            loss, log = model.forward(pos_g, neg_g, gpu_id)
        loss.backward()
        logs.append(log)
        model.update(gpu_id)

        if step % args.log_interval == 0:
            for k in logs[0].keys():
                v = sum(l[k] for l in logs) / len(logs)
                print("[Train]({}/{}) average {}: {}".format(
                    step, args.max_step, k, v))
            logs = []
            print(time.time() - start)
            start = time.time()

        if (args.valid and step % args.eval_interval == 0 and step > 1
                and valid_samplers is not None):
            start = time.time()
            test(args, model, valid_samplers, mode="Valid")
            print("test:", time.time() - start)
    if args.strict_rel_part:
        model.writeback_relation(rank, rel_parts)
    if mxprofiler:
        nd.waitall()
        profiler.set_state("stop")
        profiler.dump()
        print(profiler.dumps())
    # clear cache
    logs = []
Example #20
def validNet(net, valid_data, loss, eval_metric, epoch, config, logger, ctx):
    if not logger:
        assert False, 'require a logger'

    valid_data.reset()
    if eval_metric:
        eval_metric.reset()

    w = config.TRAIN.w
    batchsize = config.TRAIN.batchsize
    UseMetric = config.TRAIN.UseMetric
    seqLength = config.DATASET.seqLength
    nJoints = config.NETWORK.nJoints

    loss1, loss2, n1, n2 = [0] * len(ctx), [0] * len(ctx), 0.000001, 0.000001
    for batch_i, batch in enumerate(valid_data):
        data_list = gluon.utils.split_and_load(batch.data[0],
                                               ctx_list=ctx,
                                               batch_axis=1)
        label_list = gluon.utils.split_and_load(batch.label[0],
                                                ctx_list=ctx,
                                                batch_axis=1)

        Ls1, Ls2, output_list = [], [], []
        # forward
        for data, label, cx in zip(data_list, label_list, ctx):
            initial_state = [
                nd.zeros(shape=(batchsize, config.NETWORK.hidden_dim), ctx=cx)
                for _ in range(2)
            ]
            start_token = nd.ones(shape=(batchsize, 3 * nJoints), ctx=cx)
            preds = net(data, initial_state, start_token)
            output_list.append(preds)  # pred=[seqLength, 64x48]

            L1, L2 = 0, 0
            for pd, lb in zip(preds, label):
                L1 = L1 + loss(pd, lb)
            if seqLength > 1:
                for i in range(1, seqLength):
                    deltaP = preds[i] - preds[i - 1]
                    deltaG = label[i] - label[i - 1]
                    L2 = L2 + loss(deltaP, deltaG)
            Ls1.append(L1)
            Ls2.append(L2) if seqLength > 1 else Ls2.append(nd.zeros(1))

        # number
        n1 = n1 + len(ctx) * batchsize * seqLength
        n2 = n2 + len(ctx) * batchsize * (seqLength - 1)

        # loss
        for i in range(len(loss1)):
            loss1[i] += Ls1[i]
            loss2[i] += Ls2[i]

        # metric, save time
        if UseMetric:
            for pred_batch, label_batch in zip(
                    output_list, label_list):  # for each timestamp
                for t_pred, t_label in zip(pred_batch, label_batch):
                    eval_metric.update(t_label, t_pred)
    nd.waitall()
    loss1 = sum([item.sum().asscalar() for item in loss1])
    loss2 = sum([item.sum().asscalar() for item in loss2])
    validloss = loss1 / n1 + w * loss2 / n2
    MPJPE = eval_metric.get()[-1].sum(axis=0) / 17 if UseMetric else 0

    logger.info(
        "VALID - Epoch:%2d Loss1:%.2e Loss2(%2d):%.2e TotalLoss:%.2e MPJPE:%.1f"
        % (epoch + 1, loss1 / n1, w, loss2 / n2, validloss, MPJPE))
Example #21
import mxnet as mx
import mxnet.ndarray as nd
import time
N = 10
a = mx.random.normal(0, 1, (4, 1024, 1024), ctx=mx.gpu(0))
b = mx.nd.empty((4, 1024, 1024), ctx=mx.cpu())
c = mx.nd.empty((4, 1024, 1024), ctx=mx.gpu())
nd.waitall()
a.asnumpy()
start = time.time()
for i in range(N):
    b[:] = a
    nd.waitall()
avg_time = (time.time() - start)/N
print('GPU->CPU, GB/s: %g' %(4 * a.size * 1E-9 / avg_time))

start = time.time()
for i in range(N):
    a[:] = b
    nd.waitall()
avg_time = (time.time() - start)/N
print('CPU->GPU, GB/s: %g' %(4 * a.size * 1E-9 / avg_time))

start = time.time()
for i in range(N):
    c[:] = a
    nd.waitall()
avg_time = (time.time() - start)/N
print('GPU->GPU, GB/s: %g' %(4 * a.size * 1E-9 / avg_time))
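The same measurement extends to page-locked host memory, which usually copies faster over PCIe; a hedged variant assuming the a, N and timing pattern above and an MXNet build that exposes mx.cpu_pinned():

b_pinned = mx.nd.empty((4, 1024, 1024), ctx=mx.cpu_pinned())
nd.waitall()

start = time.time()
for i in range(N):
    b_pinned[:] = a    # device-to-host copy into pinned memory
    nd.waitall()       # wait for the async copy to finish before timing
avg_time = (time.time() - start)/N
print('GPU->pinned CPU, GB/s: %g' %(4 * a.size * 1E-9 / avg_time))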
Example #22
    def fit(self,
            train_data,
            eval_data=None,
            eval_metric='acc',
            epoch_end_callback=None,
            batch_end_callback=None,
            kvstore='local',
            optimizer='sgd',
            optimizer_params=(('learning_rate', 0.01), ),
            eval_end_callback=None,
            eval_batch_end_callback=None,
            initializer=Uniform(0.01),
            arg_params=None,
            aux_params=None,
            allow_missing=False,
            force_rebind=False,
            force_init=False,
            begin_epoch=0,
            num_epoch=None,
            validation_metric=None,
            monitor=None):
        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data,
                  label_shapes=train_data.provide_label,
                  for_training=True,
                  force_rebind=force_rebind)
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer,
                         arg_params=arg_params,
                         aux_params=aux_params,
                         allow_missing=allow_missing,
                         force_init=force_init)
        self.init_optimizer(kvstore=kvstore,
                            optimizer=optimizer,
                            optimizer_params=optimizer_params)

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)
        ####chris_arg
        if int(os.getenv("TASK_LIMIT",
                         0)) != 0:  # 0: no per-task limit; 1: per-task limit, updated every round; 2: per-task limit, fixed
            get_task_cmd = "sh /home/ubuntu/tc.sh -l 1"
        else:
            self.logger.info("no_task_bandwidth_limit")
            get_task_cmd = "sh /home/ubuntu/tc.sh -l 0"
        os.system(get_task_cmd)
        delay_time = float(os.getenv("DELAY_TIME", 0.8))
        ps_upload_bandwidth_part1 = int(os.getenv("PS_UPLOAD_BANDWIDTH1",
                                                  2000))
        worker_upload_bandwidth_part1 = int(
            os.getenv("WORKER_UPLOAD_BANDWIDTH1", 2000))
        ps_upload_bandwidth_part2 = int(os.getenv("PS_UPLOAD_BANDWIDTH2",
                                                  2000))
        worker_upload_bandwidth_part2 = int(
            os.getenv("WORKER_UPLOAD_BANDWIDTH2", 2000))
        tc_command = "sudo tc class change dev {} parent 1: classid 1:3 htb rate {}mbit ceil {}mbit  && sudo tc class change dev {} parent 1: classid 1:4 htb rate {}mbit ceil {}mbit"
        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)
            while not end_of_batch:
                data_batch = next_data_batch
                if monitor is not None:
                    monitor.tic()
                self.forward(data_batch, is_train=True)
                if int(os.getenv("TASK_LIMIT", 0)) == 1:
                    ##first part bandwidth allocation
                    ndarray.waitall()
                    # self.logger.info("change bandwidth part1:, "+str(time.time()))
                    x = str(ps_upload_bandwidth_part1)
                    y = str(worker_upload_bandwidth_part1)
                    cmd_up = tc_command.format("ens3", x, x, "ens3", y, y)
                    cmd_down = tc_command.format("ifb0", y, y, "ifb0", x, x)
                    os.system(cmd_up)
                    # os.system(cmd_down)
                # self.logger.info("after forward, "+str(time.time()))
                self.backward()
                # self.logger.info("before update: "+str(time.time()))
                self.update()  # executed asynchronously
                if int(os.getenv("TASK_LIMIT", 0)) == 1:
                    x = str(ps_upload_bandwidth_part2)
                    y = str(worker_upload_bandwidth_part2)
                    cmd_up = tc_command.format("ens3", x, x, "ens3", y, y)
                    cmd_down = tc_command.format("ifb0", y, y, "ifb0", x, x)
                    time.sleep(delay_time)
                    ##second part bandwidth allocation
                    # self.logger.info("change bandwidth part2:, "+str(time.time()))
                    os.system(cmd_up)
                    # os.system(cmd_down)
                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch)
                except StopIteration:
                    end_of_batch = True

                self.update_metric(eval_metric, data_batch.label)

                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1

            # one epoch of training is finished
            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            #----------------------------------------
            # evaluation on validation set
            if eval_data:
                res = self.score(eval_data,
                                 validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback,
                                 epoch=epoch)
                #TODO: pull this into default
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
Example #23
def trainNet(net, trainer, train_data, loss, train_metric, epoch, config,
             logger, ctx):
    if not logger:
        assert False, 'require a logger'

    train_data.reset()  # reset and re-shuffle
    if train_metric:
        train_metric.reset()

    trainer.set_learning_rate(config.TRAIN.lr *
                              pow(config.TRAIN.decay_rate, epoch))

    w = config.TRAIN.w
    batchsize = config.TRAIN.batchsize
    UseMetric = config.TRAIN.UseMetric
    seqLength = config.DATASET.seqLength
    nJoints = config.NETWORK.nJoints

    loss1, loss2, n1, n2 = [0] * len(ctx), [0] * len(ctx), 0.000001, 0.000001
    RecordTime = {'load': 0, 'forward': 0, 'backward': 0, 'post': 0}

    for batch_i, batch in enumerate(train_data):
        beginT = time.time()
        data_list = gluon.utils.split_and_load(batch.data[0],
                                               ctx_list=ctx,
                                               batch_axis=1)
        label_list = gluon.utils.split_and_load(
            batch.label[0], ctx_list=ctx,
            batch_axis=1)  # [[seqLength x 64 x 48] , 4]
        RecordTime['load'] += time.time() - beginT

        # forward
        beginT = time.time()
        Ls, Ls1, Ls2, output_list = [], [], [], []
        with autograd.record():
            for data, label, cx in zip(data_list, label_list, ctx):
                initial_state = [
                    nd.zeros(shape=(batchsize, config.NETWORK.hidden_dim),
                             ctx=cx) for _ in range(2)
                ]
                start_token = nd.ones(shape=(batchsize, 3 * nJoints), ctx=cx)
                preds = net(data, initial_state, start_token)
                output_list.append(preds)  # pred=[5, 64x48]

                L1, L2 = 0, 0
                for pd, lb in zip(preds, label):
                    L1 = L1 + loss(pd, lb)
                if seqLength > 1:
                    for i in range(1, seqLength):
                        deltaP = preds[i] - preds[i - 1]
                        deltaG = label[i] - label[i - 1]
                        L2 = L2 + loss(deltaP, deltaG)
                Ls1.append(L1)
                Ls2.append(L2) if seqLength > 1 else Ls2.append(nd.zeros(1))
                Ls.append(L1 + w * L2)
        RecordTime['forward'] += time.time() - beginT

        # backward
        beginT = time.time()
        for L in Ls:
            L.backward()
        trainer.step(len(ctx) * batchsize)
        RecordTime['backward'] += time.time() - beginT

        beginT = time.time()
        # number
        n1 = n1 + len(ctx) * batchsize * seqLength
        n2 = n2 + len(ctx) * batchsize * (seqLength - 1)

        # loss
        for i in range(len(loss1)):
            loss1[i] += Ls1[i]
            loss2[i] += Ls2[i]

        # metric, save time
        if UseMetric:
            for pred_batch, label_batch in zip(
                    output_list, label_list):  # for each timestamp
                for t_pred, t_label in zip(pred_batch, label_batch):
                    train_metric.update(t_label, t_pred)
        RecordTime['post'] += time.time() - beginT

    totalT = nd.array([RecordTime[k] for k in RecordTime]).sum().asscalar()
    for key in RecordTime:
        print("%-s: %.1fs %.1f%% " %
              (key, RecordTime[key], RecordTime[key] / totalT * 100),
              end=" ")
    print(" ")

    nd.waitall()
    loss1 = sum([item.sum().asscalar() for item in loss1])
    loss2 = sum([item.sum().asscalar() for item in loss2])
    TotalLoss = loss1 / n1 + w * loss2 / n2
    MPJPE = train_metric.get()[-1].sum(
        axis=0).asscalar() / 17 if UseMetric else 0

    logger.info(
        "TRAIN - Epoch:%2d LR:%.2e Loss1:%.2e Loss2(%2d):%.2e TotalLoss:%.2e MPJPE:%.1f"
        % (epoch + 1, trainer.learning_rate, loss1 / n1, w, loss2 / n2,
           TotalLoss, MPJPE))

    if ((epoch + 1) % (config.end_epoch / 4) == 0
            or epoch == 0):  # save checkpoint
        saveModel(net, logger, config, isCKP=True, epoch=epoch + 1)
    if (epoch + 1 == config.end_epoch):  # save final model
        saveModel(net, logger, config, isCKP=False, epoch=epoch + 1)
Example #24
def test1():
    ctx = mx.gpu()
    mx.random.seed(128)
    Cout = 256
    kernel_max = 3
    N, Cin, Height, Width = (128, 256, 112, 112)
    # mask = get_random_mask(Cout,Cin,(3,3),kernel_max).as_in_context(ctx)
    mask = nd.array([[[0, 1, 2]] * Cin] * Cout, dtype='float32', ctx=ctx)
    mask_ = nd.array([[0, 1, 2]] * Cin, dtype='float32', ctx=ctx)

    weight = nd.random.normal(0, 1e-2, (Cout, Cin, kernel_max))
    amount = 10
    inpt = nd.random.uniform(0, 1, (N, Cin, Height, Width), ctx=ctx)

    # custom original convolution
    net0 = MyConv(Cout, Cin, (1, kernel_max), kernel_max, mask,
                  weight_initializer=mx.init.Constant(weight.expand_dims(2)), padding=(0, 1))
    net0.initialize(ctx=ctx)
    tic = time.time()
    for _ in range(amount):
        out0 = net0(inpt)
        nd.waitall()
    # out0 = net0(inpt)
    timeuse0 = (time.time() - tic) / amount
    # out0 = out0[:, :, :-2, 1:-1]

    # FLK v1
    # net = FLKConv(Cout, Cin, (kernel_max, kernel_max), kernel_max, mask, weight_initializer=mx.init.Constant(weight))
    # net.initialize(ctx=ctx)
    # tic = time.time()
    # for _ in range(amount):
    #     out1 = net(inpt)
    #     nd.waitall()
    # timeuse1 = (time.time() - tic) / amount

    # net2 = FLKConv_v2(Cout, Cin, (kernel_max, kernel_max), kernel_max, mask,
    #                   weight_initializer=mx.init.Constant(weight))
    # net2.initialize(ctx=ctx)
    # tic = time.time()
    # for _ in range(amount):
    #     out2 = net2(inpt)
    #     nd.waitall()
    # timeuse2 = (time.time() - tic) / amount
    #
    # net3 = FLKConv_v3(Cout, Cin, (kernel_max, kernel_max), kernel_max, mask,
    #                   weight_initializer=mx.init.Constant(weight))
    # net3.initialize(ctx=ctx)
    # tic = time.time()
    # for _ in range(amount):
    #     out3 = net3(inpt)
    #     nd.waitall()
    # timeuse3 = (time.time() - tic) / amount

    net4 = FLKConv_v4(Cout, Cin, (kernel_max, kernel_max), kernel_max, mask_,
                      weight_initializer=mx.init.Constant(weight))
    net4.initialize(ctx=ctx)
    tic = time.time()
    for _ in range(amount):
        out4 = net4(inpt)
        nd.waitall()
    timeuse4 = (time.time() - tic) / amount


    print('')
Example #25
def train(args):
    np.random.seed(args.seed)
    if args.gpu:
        ctx = [mx.gpu(0)]
    else:
        ctx = [mx.cpu(0)]
    if args.dataset == "Sony":
        out_channels = 12
        scale = 2
    else:
        out_channels = 27
        scale = 3

    # load data
    train_transform = utils.Compose([
        utils.RandomCrop(args.patch_size, scale),
        utils.RandomFlipLeftRight(),
        utils.RandomFlipTopBottom(),
        utils.RandomTranspose(),
        utils.ToTensor(),
    ])
    train_dataset = data.MyDataset(args.dataset,
                                   "train",
                                   transform=train_transform)
    val_transform = utils.Compose([utils.ToTensor()])
    val_dataset = data.MyDataset(args.dataset, "val", transform=val_transform)
    train_loader = gluon.data.DataLoader(train_dataset,
                                         shuffle=True,
                                         batch_size=args.batch_size,
                                         last_batch='rollover')
    val_loader = gluon.data.DataLoader(val_dataset,
                                       batch_size=1,
                                       last_batch='discard')
    unet = net.UNet(out_channels, scale)
    unet.initialize(init=initializer.Xavier(), ctx=ctx)

    # optimizer and loss
    trainer = gluon.Trainer(unet.collect_params(), 'adam',
                            {'learning_rate': args.lr})
    l1_loss = gluon.loss.L1Loss()

    print "Start training now.."
    for i in range(args.epochs):
        total_loss = 0
        count = 0
        profiler.set_state('run')
        for batch_id, (img, gt) in enumerate(train_loader):
            batch_size = img.shape[0]
            count += batch_size
            img_list = gluon.utils.split_and_load(img[0], ctx)
            gt_list = gluon.utils.split_and_load(gt[0], ctx)
            with autograd.record():
                preds = [unet(x) for x in img_list]
                losses = []
                for ii in range(len(preds)):
                    loss = l1_loss(gt_list[ii], preds[ii])
                    losses.append(loss)
            for loss in losses:
                loss.backward()
            total_loss += sum([l.sum().asscalar() for l in losses])
            avg_loss = total_loss / count
            trainer.step(batch_size)
            metric.update(gt_list, preds)
            F.waitall()
            profiler.set_state('stop')
            print(profiler.dumps())
            break
            gt_save = gt_list[0]
            output_save = preds[0]

            if (batch_id + 1) % 100 == 0:
                message = "Epoch {}: [{}/{}]: l1_loss: {:.4f}".format(
                    i + 1, count, len(train_dataset), avg_loss)
                print(message)
        temp = F.concat(gt_save, output_save, dim=3)
        temp = temp.asnumpy().reshape(temp.shape[2], temp.shape[3], 3)
        scipy.misc.toimage(temp * 255,
                           high=255,
                           low=0,
                           cmin=0,
                           cmax=255,
                           mode='RGB').save(args.save_model_dir +
                                            '%04d_%05d_00_train.jpg' %
                                            (i + 1, count))

        # evaluate
        batches = 0
        avg_psnr = 0.
        for img, gt in val_loader:
            batches += 1
            imgs = gluon.utils.split_and_load(img[0], ctx)
            label = gluon.utils.split_and_load(gt[0], ctx)
            outputs = []
            for x in imgs:
                outputs.append(unet(x))
            metric.update(label, outputs)
            avg_psnr += 10 * math.log10(1 / metric.get()[1])
            metric.reset()
        avg_psnr /= batches
        print('Epoch {}: validation avg psnr: {:.3f}'.format(i + 1, avg_psnr))

        # save model
        if (i + 1) % args.save_freq == 0:
            save_model_filename = "Epoch_" + str(i + 1) + ".params"
            save_model_path = os.path.join(args.save_model_dir,
                                           save_model_filename)
            unet.save_params(save_model_path)
            print("\nCheckpoint, trained model saved at", save_model_path)

    # save model
    save_model_filename = "Final_Epoch_" + str(i + 1) + ".params"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    unet.save_params(save_model_path)
    print("\nCheckpoint, trained model saved at", save_model_path)
Example #26
def train(num_gpus=GPU_COUNT,
          batch_size=batch_size,
          lr=0.001,
          train_epoch=train_epoch):
    # train the model
    train_iter, test_iter = train_loader, val_loader
    ctx = [mx.gpu(i) for i in range(num_gpus)]
    # print('running on:', ctx)
    logging.info('running on:GPU')
    # loss = gluon.loss.SoftmaxCrossEntropyLoss()

    best_acc = 0.
    # loss = L2Softmax(classes = 25,alpha = 10)
    trainer = gluon.Trainer(net.collect_params(), 'Adam',
                            {'learning_rate': lr})

    for epoch in range(train_epoch):
        train_loss = 0.
        start = time.time()
        for i, (X, y) in enumerate(train_iter):
            X = X / 255
            gpu_Xs = gutils.split_and_load(X, ctx, even_split=False)
            gpu_ys = gutils.split_and_load(y, ctx, even_split=False)
            with autograd.record():
                # ls = [loss(net(gpu_X), mx.nd.one_hot(gpu_y,classes))
                #       for gpu_X, gpu_y in zip(gpu_Xs, gpu_ys)]
                ls = [
                    loss(net(gpu_X), gpu_y)
                    for gpu_X, gpu_y in zip(gpu_Xs, gpu_ys)
                ]
            for l in ls:
                l.backward()
            trainer.step(batch_size)
            ls_list = [nd.mean(i).asscalar() for i in ls]
            train_loss += sum(ls_list) / len(ls_list)
            if (i + 1) % 50 == 0:
                # print the training log
                logging.info(
                    "epoch[{epoch}]  batch_num[{batch_num}]  epochtrain_loss : {loss}"
                    .format(epoch=epoch + 1,
                            batch_num=i + 1,
                            loss=train_loss / (i + 1)))
                # test_acc = evaluate_accuracy(test_iter, net, ctx)
                # train_time = time.time() - start
                # logging.info('epoch %d, time %.1f sec, test acc %.7f' % (
                #     epoch + 1, train_time, test_acc))
                # net.save_parameters('weights/best_' + save_name + '.params')
        nd.waitall()
        train_time = time.time() - start
        test_acc = evaluate_accuracy(test_iter, net, ctx)
        logging.info('epoch[%d], time %.5f sec, test acc %.7f' %
                     (epoch + 1, train_time, test_acc))
        if test_acc > best_acc:
            best_acc = test_acc
            net.save_parameters('weights/best_' + save_name + '.params')
        if (epoch + 1) % 5 == 0:
            net.save_parameters('weights/gluon_' + save_name + str(epoch + 1) +
                                '.params')
        net.save_parameters('weights/gluon_' + save_name + str(epoch + 1) +
                            '.params')
        net.save_parameters('weights/last_' + save_name + '.params')
Example #27
def main():
    parser = argparse.ArgumentParser(description='Script to test the trained network on a game.')
    parser.add_argument('-r', '--rom', required=False, type=str,
                        default=os.path.join('arena', 'games', 'roms', 'breakout.bin'),
                        help='Path of the ROM File.')
    parser.add_argument('-v', '--visualization', required=False, type=int, default=0,
                        help='Visualize the runs.')
    parser.add_argument('--lr', required=False, type=float, default=0.01,
                        help='Learning rate of the AdaGrad optimizer')
    parser.add_argument('--eps', required=False, type=float, default=0.01,
                        help='Eps of the AdaGrad optimizer')
    parser.add_argument('--clip-gradient', required=False, type=float, default=None,
                        help='Clip threshold of the AdaGrad optimizer')
    parser.add_argument('--double-q', required=False, type=bool, default=False,
                        help='Use Double DQN')
    parser.add_argument('--wd', required=False, type=float, default=0.0,
                        help='Weight of the L2 Regularizer')
    parser.add_argument('-c', '--ctx', required=False, type=str, default='gpu',
                        help='Running Context. E.g `-c gpu` or `-c gpu1` or `-c cpu`')
    parser.add_argument('-d', '--dir-path', required=False, type=str, default='',
                        help='Saving directory of model files.')
    parser.add_argument('--start-eps', required=False, type=float, default=1.0,
                        help='Eps of the epsilon-greedy policy at the beginning')
    parser.add_argument('--replay-start-size', required=False, type=int, default=50000,
                        help='The step that the training starts')
    parser.add_argument('--kvstore-update-period', required=False, type=int, default=1,
                        help='The period that the worker updates the parameters from the sever')
    parser.add_argument('--kv-type', required=False, type=str, default=None,
                        help='type of kvstore, default will not use kvstore, could also be dist_async')
    args, unknown = parser.parse_known_args()
    if args.dir_path == '':
        rom_name = os.path.splitext(os.path.basename(args.rom))[0]
        args.dir_path = 'dqn-%s' % rom_name
    ctx = re.findall(r'([a-z]+)(\d*)', args.ctx)
    ctx = [(device, int(num)) if len(num) > 0 else (device, 0) for device, num in ctx]
    replay_start_size = args.replay_start_size
    max_start_nullops = 30
    replay_memory_size = 1000000
    history_length = 4
    rows = 84
    cols = 84
    q_ctx = mx.Context(*ctx[0])

    game = AtariGame(rom_path=args.rom, resize_mode='scale', replay_start_size=replay_start_size,
                     resized_rows=rows, resized_cols=cols, max_null_op=max_start_nullops,
                     replay_memory_size=replay_memory_size, display_screen=args.visualization,
                     history_length=history_length)

    ##RUN NATURE
    freeze_interval = 10000
    epoch_num = 200
    steps_per_epoch = 250000
    update_interval = 4
    discount = 0.99

    eps_start = args.start_eps
    eps_min = 0.1
    eps_decay = (eps_start - 0.1) / 1000000
    eps_curr = eps_start
    freeze_interval /= update_interval
    minibatch_size = 32
    action_num = len(game.action_set)

    data_shapes = {'data': (minibatch_size, history_length) + (rows, cols),
                   'dqn_action': (minibatch_size,), 'dqn_reward': (minibatch_size,)}
    #optimizer = mx.optimizer.create(name='sgd', learning_rate=args.lr,wd=args.wd)
    optimizer = mx.optimizer.Nop()
    dqn_output_op = DQNOutputNpyOp()
    dqn_sym = dqn_sym_nature(action_num, dqn_output_op)
    qnet = Base(data_shapes=data_shapes, sym=dqn_sym, name='QNet',
                  initializer=DQNInitializer(factor_type="in"),
                  ctx=q_ctx)
    target_qnet = qnet.copy(name="TargetQNet", ctx=q_ctx)
    # Create kvstore
    testShape = (1,1686180*100)
    testParam = nd.ones(testShape,ctx=q_ctx)
    testGrad = nd.zeros(testShape,ctx=q_ctx)

    # Create kvstore

    if args.kv_type != None:
        kvType = args.kv_type
        kvStore = kvstore.create(kvType)
        #Initialize kvstore
        for idx,v in enumerate(qnet.params.values()):
            kvStore.init(idx,v);
        # Set optimizer on kvstore
        kvStore.set_optimizer(optimizer)
        kvstore_update_period = args.kvstore_update_period
    else:
        updater = mx.optimizer.get_updater(optimizer)

    # if args.kv_type != None:
    #     kvType = args.kv_type
    #     kvStore = kvstore.create(kvType)
    #     kvStore.init(0,testParam)
    #     testOptimizer = mx.optimizer.Nop()
    #     kvStore.set_optimizer(testOptimizer)
    #     kvstore_update_period = args.kvstore_update_period


    qnet.print_stat()
    target_qnet.print_stat()
    # Begin Playing Game
    training_steps = 0
    total_steps = 0
    while(1):
        time_before_wait = time.time()

        # kvStore.push(0,testGrad,priority=0)
        # kvStore.pull(0,testParam,priority=0)
        # testParam.wait_to_read()

        for paramIndex in range(len(qnet.params)):#range(6):#
            k=qnet.params.keys()[paramIndex]
            kvStore.push(paramIndex,qnet.params_grad[k],priority=-paramIndex)
            kvStore.pull(paramIndex,qnet.params[k],priority=-paramIndex)

        for v in qnet.params.values():
            v.wait_to_read()
        logging.info("wait time %f" %(time.time()-time_before_wait))

    for epoch in xrange(epoch_num):
        # Run Epoch
        steps_left = steps_per_epoch
        episode = 0
        epoch_reward = 0
        start = time.time()
        game.start()
        while steps_left > 0:
            # Running New Episode
            episode += 1
            episode_loss = 0.0
            episode_q_value = 0.0
            episode_update_step = 0
            episode_action_step = 0
            time_episode_start = time.time()
            game.begin_episode(steps_left)
            while not game.episode_terminate:
                # 1. We need to choose a new action based on the current game status
                if game.state_enabled and game.replay_memory.sample_enabled:
                    do_exploration = (npy_rng.rand() < eps_curr)
                    eps_curr = max(eps_curr - eps_decay, eps_min)
                    if do_exploration:
                        action = npy_rng.randint(action_num)
                    else:
                        # TODO Here we can in fact play multiple gaming instances simultaneously and make actions for each
                        # We can simply stack the current_state() of gaming instances and give prediction for all of them
                        # We need to wait after calling calc_score(.), which makes the program slow
                        # TODO Profiling the speed of this part!
                        current_state = game.current_state()
                        state = nd.array(current_state.reshape((1,) + current_state.shape),
                                         ctx=q_ctx) / float(255.0)
                        qval_npy = qnet.forward(batch_size=1, data=state)[0].asnumpy()
                        action = numpy.argmax(qval_npy)
                        episode_q_value += qval_npy[0, action]
                        episode_action_step += 1
                else:
                    action = npy_rng.randint(action_num)

                # 2. Play the game for a single mega-step (Inside the game, the action may be repeated for several times)
                game.play(action)
                total_steps += 1

                # 3. Update our Q network if we can start sampling from the replay memory
                #    Also, we update every `update_interval`
                if total_steps % update_interval == 0 and game.replay_memory.sample_enabled:
                    # 3.1 Draw sample from the replay_memory
                    training_steps += 1
                    episode_update_step += 1
                    states, actions, rewards, next_states, terminate_flags \
                        = game.replay_memory.sample(batch_size=minibatch_size)
                    states = nd.array(states, ctx=q_ctx) / float(255.0)
                    next_states = nd.array(next_states, ctx=q_ctx) / float(255.0)
                    actions = nd.array(actions, ctx=q_ctx)
                    rewards = nd.array(rewards, ctx=q_ctx)
                    terminate_flags = nd.array(terminate_flags, ctx=q_ctx)

                    # 3.2 Use the target network to compute the scores and
                    #     get the corresponding target rewards
                    if not args.double_q:
                        target_qval = target_qnet.forward(batch_size=minibatch_size,
                                                         data=next_states)[0]
                        target_rewards = rewards + nd.choose_element_0index(target_qval,
                                                                nd.argmax_channel(target_qval))\
                                           * (1.0 - terminate_flags) * discount
                    else:
                        target_qval = target_qnet.forward(batch_size=minibatch_size,
                                                         data=next_states)[0]
                        qval = qnet.forward(batch_size=minibatch_size, data=next_states)[0]

                        target_rewards = rewards + nd.choose_element_0index(target_qval,
                                                                nd.argmax_channel(qval))\
                                           * (1.0 - terminate_flags) * discount
                    outputs = qnet.forward(batch_size=minibatch_size,is_train=True, data=states,
                                              dqn_action=actions,
                                              dqn_reward=target_rewards)
                    qnet.backward(batch_size=minibatch_size)
                    nd.waitall()
                    time_before_update = time.time()

                    if args.kv_type != None:
                        if total_steps % kvstore_update_period == 0:
                            update_to_kvstore(kvStore,qnet.params,qnet.params_grad)
                    else:
                        qnet.update(updater=updater)
                    logging.info("update time %f" %(time.time()-time_before_update))
                    time_before_wait = time.time()
                    nd.waitall()
                    logging.info("wait time %f" %(time.time()-time_before_wait))

                    '''nd.waitall()
                    time_before_wait = time.time()
                    kvStore.push(0,testGrad,priority=0)
                    kvStore.pull(0,testParam,priority=0)
                    nd.waitall()
                    logging.info("wait time %f" %(time.time()-time_before_wait))'''
                    # 3.3 Calculate Loss
                    diff = nd.abs(nd.choose_element_0index(outputs[0], actions) - target_rewards)
                    quadratic_part = nd.clip(diff, -1, 1)
                    loss = (0.5 * nd.sum(nd.square(quadratic_part)) + nd.sum(diff - quadratic_part)).asscalar()
                    episode_loss += loss

                    # 3.3 Update the target network every freeze_interval
                    # (We can do annealing instead of hard copy)
                    if training_steps % freeze_interval == 0:
                        qnet.copy_params_to(target_qnet)
            steps_left -= game.episode_step
            time_episode_end = time.time()
            # Update the statistics
            epoch_reward += game.episode_reward
            info_str = "Epoch:%d, Episode:%d, Steps Left:%d/%d, Reward:%f, fps:%f, Exploration:%f" \
                        % (epoch, episode, steps_left, steps_per_epoch, game.episode_reward,
                           game.episode_step / (time_episode_end - time_episode_start), eps_curr)
            if episode_update_step > 0:
                info_str += ", Avg Loss:%f/%d" % (episode_loss / episode_update_step,
                                                  episode_update_step)
            if episode_action_step > 0:
                info_str += ", Avg Q Value:%f/%d" % (episode_q_value / episode_action_step,
                                                  episode_action_step)
            logging.info(info_str)
        end = time.time()
        fps = steps_per_epoch / (end - start)
        qnet.save_params(dir_path=args.dir_path, epoch=epoch)
        logging.info("Epoch:%d, FPS:%f, Avg Reward: %f/%d"
                     % (epoch, fps, epoch_reward / float(episode), episode))
Example #28
def test_RNN_class(typ="lstm", ret_typ="out"):
    num_hidden = (128, 128, 128)
    data_dim = 512
    seq_len = 20
    minibatch_size = 8
    data_npy = numpy.random.standard_normal(
        (seq_len, minibatch_size, data_dim))
    init_h_0 = numpy.random.standard_normal((minibatch_size, num_hidden[0]))
    init_h_1 = numpy.random.standard_normal((minibatch_size, num_hidden[1]))
    init_h_2 = numpy.random.standard_normal((minibatch_size, num_hidden[2]))

    if typ == "lstm":
        init_c_0 = numpy.random.standard_normal(
            (minibatch_size, num_hidden[0]))
        init_c_1 = numpy.random.standard_normal(
            (minibatch_size, num_hidden[1]))
        init_c_2 = numpy.random.standard_normal(
            (minibatch_size, num_hidden[2]))
    param_shapes = get_rnn_param_shapes(num_hidden=num_hidden,
                                        data_dim=data_dim,
                                        typ=typ)
    i2h_weight_npy = [
        numpy.random.standard_normal(s) for s in param_shapes["i2h_weight"]
    ]
    i2h_bias_npy = [
        numpy.random.standard_normal(s) / 100 for s in param_shapes["i2h_bias"]
    ]
    h2h_weight_npy = [
        numpy.random.standard_normal(s) for s in param_shapes["h2h_weight"]
    ]
    h2h_bias_npy = [
        numpy.random.standard_normal(s) / 100 for s in param_shapes["h2h_bias"]
    ]
    if ret_typ == "out":
        out_grad_npy = numpy.random.standard_normal(
            (seq_len, minibatch_size, num_hidden[2]))
    elif ret_typ == "state":
        mult = 2 if typ == "lstm" else 1
        out_grad_npy = numpy.random.standard_normal(
            (minibatch_size,
             mult * (num_hidden[0] + num_hidden[1] + num_hidden[2])))
    data = mx.sym.Variable("data")
    rnn = RNN(data_dim=data_dim,
              num_hidden=num_hidden,
              typ=typ,
              name=typ.upper())
    rnn_cudnn = RNN(data_dim=data_dim,
                    num_hidden=num_hidden,
                    typ=typ,
                    cudnn_opt=True,
                    name=typ.upper() + "-cudnn")
    if ret_typ == "state":
        if typ == "lstm":
            rnn_out_h, rnn_out_c = rnn.step(data=data,
                                            seq_len=seq_len,
                                            ret_typ=ret_typ)
            rnn_cudnn_out_h, rnn_cudnn_out_c = rnn_cudnn.step(data=data,
                                                              seq_len=seq_len,
                                                              ret_typ=ret_typ)
            rnn_sym = mx.sym.Concat(*(rnn_out_h + rnn_out_c),
                                    num_args=len(num_hidden) * 2,
                                    dim=1)
            rnn_cudnn_sym = mx.sym.Concat(*(rnn_cudnn_out_h + rnn_cudnn_out_c),
                                          num_args=len(num_hidden) * 2,
                                          dim=1)
        else:
            rnn_out_h = rnn.step(data=data, seq_len=seq_len, ret_typ=ret_typ)
            rnn_cudnn_out_h = rnn_cudnn.step(data=data,
                                             seq_len=seq_len,
                                             ret_typ=ret_typ)
            rnn_sym = mx.sym.Concat(*(rnn_out_h),
                                    num_args=len(num_hidden),
                                    dim=1)
            rnn_cudnn_sym = mx.sym.Concat(*(rnn_cudnn_out_h),
                                          num_args=len(num_hidden),
                                          dim=1)
    else:
        rnn_out_h = rnn.step(data=data, seq_len=seq_len, ret_typ=ret_typ)
        rnn_cudnn_out_h = rnn_cudnn.step(data=data,
                                         seq_len=seq_len,
                                         ret_typ=ret_typ)
        rnn_sym = rnn_out_h[-1]
        rnn_cudnn_sym = rnn_cudnn_out_h[-1]

    if typ == "lstm":
        rnn_exe = rnn_sym.simple_bind(
            ctx=mx.gpu(),
            **{
                'data': (seq_len, minibatch_size, data_dim),
                rnn.name + '->layer0:init_h': (minibatch_size, num_hidden[0]),
                rnn.name + '->layer0:init_c': (minibatch_size, num_hidden[0]),
                rnn.name + '->layer1:init_h': (minibatch_size, num_hidden[1]),
                rnn.name + '->layer1:init_c': (minibatch_size, num_hidden[1]),
                rnn.name + '->layer2:init_h': (minibatch_size, num_hidden[2]),
                rnn.name + '->layer2:init_c': (minibatch_size, num_hidden[2])
            })
        rnn_cudnn_exe = rnn_cudnn_sym.simple_bind(
            ctx=mx.gpu(),
            **{
                'data': (seq_len, minibatch_size, data_dim),
                rnn_cudnn.name + '->layer0:init_h':
                (minibatch_size, num_hidden[0]),
                rnn_cudnn.name + '->layer0:init_c':
                (minibatch_size, num_hidden[0]),
                rnn_cudnn.name + '->layer1:init_h':
                (minibatch_size, num_hidden[1]),
                rnn_cudnn.name + '->layer1:init_c':
                (minibatch_size, num_hidden[1]),
                rnn_cudnn.name + '->layer2:init_h': (minibatch_size,
                                                     num_hidden[2]),
                rnn_cudnn.name + '->layer2:init_c': (minibatch_size,
                                                     num_hidden[2])
            })
    else:
        rnn_exe = rnn_sym.simple_bind(ctx=mx.gpu(),
                                      **{
                                          'data':
                                          (seq_len, minibatch_size, data_dim),
                                          rnn.name + '->layer0:init_h':
                                          (minibatch_size, num_hidden[0]),
                                          rnn.name + '->layer1:init_h':
                                          (minibatch_size, num_hidden[1]),
                                          rnn.name + '->layer2:init_h':
                                          (minibatch_size, num_hidden[2])
                                      })
        rnn_cudnn_exe = rnn_cudnn_sym.simple_bind(
            ctx=mx.gpu(),
            **{
                'data': (seq_len, minibatch_size, data_dim),
                rnn_cudnn.name + '->layer0:init_h':
                (minibatch_size, num_hidden[0]),
                rnn_cudnn.name + '->layer1:init_h':
                (minibatch_size, num_hidden[1]),
                rnn_cudnn.name + '->layer2:init_h':
                (minibatch_size, num_hidden[2])
            })
    for i in range(len(num_hidden)):
        rnn_exe.arg_dict[rnn.name +
                         '->layer%d:i2h_weight' % i][:] = i2h_weight_npy[i]
        rnn_exe.arg_dict[rnn.name +
                         '->layer%d:h2h_weight' % i][:] = h2h_weight_npy[i]
        rnn_exe.arg_dict[rnn.name +
                         '->layer%d:i2h_bias' % i][:] = i2h_bias_npy[i]
        rnn_exe.arg_dict[rnn.name +
                         '->layer%d:h2h_bias' % i][:] = h2h_bias_npy[i]
        rnn_cudnn_exe.arg_dict[rnn_cudnn.name + '->layer%d:i2h_weight' %
                               i][:] = i2h_weight_npy[i]
        rnn_cudnn_exe.arg_dict[rnn_cudnn.name + '->layer%d:h2h_weight' %
                               i][:] = h2h_weight_npy[i]
        rnn_cudnn_exe.arg_dict[rnn_cudnn.name +
                               '->layer%d:i2h_bias' % i][:] = i2h_bias_npy[i]
        rnn_cudnn_exe.arg_dict[rnn_cudnn.name +
                               '->layer%d:h2h_bias' % i][:] = h2h_bias_npy[i]
    N = 1  # number of timed forward/backward repetitions per implementation
    if typ == "lstm":
        start = time.time()
        for j in range(N):
            rnn_outputs = rnn_exe.forward(is_train=True,
                                          **{
                                              "data":
                                              data_npy,
                                              rnn.name + '->layer0:init_h':
                                              init_h_0,
                                              rnn.name + '->layer0:init_c':
                                              init_c_0,
                                              rnn.name + '->layer1:init_h':
                                              init_h_1,
                                              rnn.name + '->layer1:init_c':
                                              init_c_1,
                                              rnn.name + '->layer2:init_h':
                                              init_h_2,
                                              rnn.name + '->layer2:init_c':
                                              init_c_2
                                          })

            rnn_exe.backward(
                out_grads=[mx.nd.array(out_grad_npy, ctx=mx.gpu())])
            nd.waitall()
        end = time.time()
        print("MXNet %s Time: %g ms" % (typ.upper(), (end - start) / N * 1000))
        start = time.time()
        for j in range(N):
            rnn_cudnn_outputs = rnn_cudnn_exe.forward(
                is_train=True,
                **{
                    "data": data_npy,
                    rnn_cudnn.name + '->layer0:init_h': init_h_0,
                    rnn_cudnn.name + '->layer0:init_c': init_c_0,
                    rnn_cudnn.name + '->layer1:init_h': init_h_1,
                    rnn_cudnn.name + '->layer1:init_c': init_c_1,
                    rnn_cudnn.name + '->layer2:init_h': init_h_2,
                    rnn_cudnn.name + '->layer2:init_c': init_c_2
                })
            rnn_cudnn_exe.backward(
                out_grads=[mx.nd.array(out_grad_npy, ctx=mx.gpu())])
            nd.waitall()
        end = time.time()
        print("CuDNN %s Time: %g ms" % (typ.upper(), (end - start) / N * 1000))
    else:
        start = time.time()
        for j in range(N):
            rnn_outputs = rnn_exe.forward(is_train=True,
                                          **{
                                              "data":
                                              data_npy,
                                              rnn.name + '->layer0:init_h':
                                              init_h_0,
                                              rnn.name + '->layer1:init_h':
                                              init_h_1,
                                              rnn.name + '->layer2:init_h':
                                              init_h_2
                                          })
            rnn_exe.backward(
                out_grads=[mx.nd.array(out_grad_npy, ctx=mx.gpu())])
            nd.waitall()
        end = time.time()
        print("MXNet %s Time: %g ms" % (typ.upper(), (end - start) / N * 1000))
        start = time.time()
        for j in range(N):
            rnn_cudnn_outputs = rnn_cudnn_exe.forward(
                is_train=True,
                **{
                    "data": data_npy,
                    rnn_cudnn.name + '->layer0:init_h': init_h_0,
                    rnn_cudnn.name + '->layer1:init_h': init_h_1,
                    rnn_cudnn.name + '->layer2:init_h': init_h_2
                })
            rnn_cudnn_exe.backward(
                out_grads=[mx.nd.array(out_grad_npy, ctx=mx.gpu())])
            nd.waitall()
        end = time.time()
        print("CuDNN %s Time: %g ms" % (typ.upper(), (end - start) / N * 1000))
    print(
        numpy.square(rnn_outputs[0].asnumpy() -
                     rnn_cudnn_outputs[0].asnumpy()).mean())
    for k, v in rnn_exe.grad_dict.items():
        if k == 'data':
            #numpy.testing.assert_allclose(v.asnumpy(), rnn_cudnn_exe.grad_dict[k].asnumpy())
            print(k, reldiff(v.asnumpy(),
                             rnn_cudnn_exe.grad_dict[k].asnumpy()))
        else:
            postfix = k[k.find("->"):]
            #numpy.testing.assert_allclose(v.asnumpy(), rnn_cudnn_exe.grad_dict[rnn_cudnn.name + postfix].asnumpy())
            print(
                k,
                reldiff(
                    v.asnumpy(), rnn_cudnn_exe.grad_dict[rnn_cudnn.name +
                                                         postfix].asnumpy()))
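
The benchmark above calls a reldiff helper that is not shown in this snippet. Below is a minimal sketch of a relative-difference measure together with a small driver; it assumes a CUDA-enabled MXNet build, that RNN, get_rnn_param_shapes and test_RNN_class are defined in the same module, and that the RNN helper also accepts typ="gru".

import numpy

def reldiff(a, b):
    # Relative difference between two numpy arrays; a sketch of the helper
    # referenced by test_RNN_class above (the original may differ).
    diff = numpy.abs(a - b).sum()
    norm = numpy.abs(a).sum() + numpy.abs(b).sum()
    return 0.0 if norm == 0 else diff / norm

if __name__ == '__main__':
    # Compare the symbolic RNN against the cuDNN fused implementation for both
    # return types and both cell types exercised by the test.
    for typ in ("lstm", "gru"):
        for ret_typ in ("out", "state"):
            test_RNN_class(typ=typ, ret_typ=ret_typ)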
Example #29
def train(cep,
          pool_size,
          epochs,
          train_data,
          val_data,
          ctx,
          netEn,
          netDe,
          netD,
          netD2,
          netDS,
          trainerEn,
          trainerDe,
          trainerD,
          trainerD2,
          trainerSD,
          lambda1,
          batch_size,
          expname,
          append=True,
          useAE=False):
    tp_file = open(expname + "_trainloss.txt", "w")
    tp_file.close()
    text_file = open(expname + "_validtest.txt", "w")
    text_file.close()
    #netGT, netDT, _, _ = set_test_network(opt.depth, ctx, opt.lr, opt.beta1,opt.ndf,  opt.ngf, opt.append)
    GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    L1_loss = gluon.loss.L2Loss()
    image_pool = imagePool.ImagePool(pool_size)
    metric = mx.metric.CustomMetric(facc)
    metric2 = mx.metric.CustomMetric(facc)
    metricStrong = mx.metric.CustomMetric(facc)
    metricMSE = mx.metric.MSE()
    loss_rec_G = []
    loss_rec_D = []
    loss_rec_R = []
    acc_rec = []
    acc2_rec = []
    loss_rec_D2 = []
    loss_rec_G2 = []
    lr = 2.0 * 512
    stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
    logging.basicConfig(level=logging.DEBUG)
    if cep == -1:
        cep = 0
    else:
        # Resume from checkpoints written by this function; they are saved under
        # expname, so load them under expname as well rather than the global opt.
        netEn.load_params('checkpoints/' + expname + '_' + str(cep) +
                          '_En.params',
                          ctx=ctx)
        netDe.load_params('checkpoints/' + expname + '_' + str(cep) +
                          '_De.params',
                          ctx=ctx)
        netD.load_params('checkpoints/' + expname + '_' + str(cep) +
                         '_D.params',
                         ctx=ctx)
        netD2.load_params('checkpoints/' + expname + '_' + str(cep) +
                          '_D2.params',
                          ctx=ctx)
        netDS.load_params('checkpoints/' + expname + '_' + str(cep) +
                          '_SD.params',
                          ctx=ctx)
    iter = 0
    for epoch in range(cep + 1, epochs):

        tic = time.time()
        btic = time.time()
        train_data.reset()
        #print('learning rate : '+str(trainerD.learning_rate ))
        for batch in train_data:
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ###########################
            # ctx is a list of devices used with split_and_load below; whole
            # arrays are created on the first device and split afterwards.
            ct = ctx[0] if isinstance(ctx, (list, tuple)) else ctx
            real_in = batch.data[0]  #.as_in_context(ctx)
            real_out = batch.data[1]  #.as_in_context(ctx)
            if iter == 0:
                latent_shape = (batch_size, 512, 1, 1)  #code.shape
                out_l_shape = (batch_size, 1, 1, 1)  #netD2((code)).shape
                out_i_shape = (batch_size, 1, 1, 1)  #netD(netDe(code)).shape
                out_s_shape = (batch_size, 1, 1, 1)  #netSD(netDe(code)).shape
            real_in = gluon.utils.split_and_load(real_in, ctx)
            real_out = gluon.utils.split_and_load(real_out, ctx)
            fake_latent = [netEn(r) for r in real_in]
            real_latent = nd.random.uniform(low=-1, high=1, shape=latent_shape)
            real_latent = gluon.utils.split_and_load(real_latent, ctx)
            fake_out = [netDe(f) for f in fake_latent]
            fake_concat = [nd.concat(r, f, dim=1)
                           for r, f in zip(real_in, fake_out)] if append else fake_out
            eps2 = nd.random.uniform(low=-1,
                                     high=1,
                                     shape=latent_shape,
                                     ctx=ct)
            eps2 = gluon.utils.split_and_load(eps2, ctx)
            if epoch > 150:  # (1/float(batch_size))*512*150:# and epoch%10==0:
                print('Mining..')
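                # Latent "mining": starting from random latents mu, take gradient
                # steps on the strong discriminator's fake-label loss and reuse the
                # updated latents as eps2 in the updates further below.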
                mu = nd.random.uniform(low=-1,
                                       high=1,
                                       shape=latent_shape,
                                       ctx=ct)
                #isigma = nd.ones((batch_size,64,1,1),ctx=ctx)*0.000001
                mu.attach_grad()
                #sigma.attach_grad()
                images = netDe(mu)
                fake_img1T = nd.concat(images[0], images[1], images[2], dim=1)
                fake_img2T = nd.concat(images[3], images[4], images[5], dim=1)
                fake_img3T = nd.concat(images[6], images[7], images[8], dim=1)
                fake_img = nd.concat(fake_img1T, fake_img2T, fake_img3T, dim=2)
                visual.visualize(fake_img)
                plt.savefig('outputs/' + expname + '_fakespre_' + str(epoch) +
                            '.png')
                eps2 = gluon.utils.split_and_load(mu, ctx)
                for e in eps2:
                    e.attach_grad()
                for ep2 in range(1):
                    with autograd.record():
                        #eps = nd.random_normal(loc=0, scale=1, shape=fake_latent.shape, ctx=ctx) #
                        #eps2 = gluon.utils.split_and_load(nd.tanh(mu),ctx) #+nd.multiply(eps,sigma))#nd.random.uniform( low=-1, high=1, shape=fake_latent.shape, ctx=ctx)
                        rec_output = [netDS(netDe(e)) for e in eps2]
                        fake_label = gluon.utils.split_and_load(
                            nd.zeros(out_s_shape), ctx)
                        errGS = [
                            GAN_loss(r, f)
                            for r, f in zip(rec_output, fake_label)
                        ]
                        for e in errGS:
                            e.backward()
                    for idx, _ in enumerate(eps2):
                        eps2[idx] = nd.tanh(eps2[idx] -
                                            lr / eps2[idx].shape[0] *
                                            eps2[idx].grad)
                images = netDe((eps2[0]))
                fake_img1T = nd.concat(images[0], images[1], images[2], dim=1)
                fake_img2T = nd.concat(images[3], images[4], images[5], dim=1)
                fake_img3T = nd.concat(images[6], images[7], images[8], dim=1)
                fake_img = nd.concat(fake_img1T, fake_img2T, fake_img3T, dim=2)
                visual.visualize(fake_img)
                plt.savefig('outputs/' + expname + str(ep2) + '_fakespost_' +
                            str(epoch) + '.png')
                #eps2 = nd.tanh(mu)#+nd.multiply(eps,sigma))#nd.random.uniform( low=-1, high=1, shape=fake_latent.shape, ctx=ctx)

            with autograd.record():
                #eps2 = gluon.utils.split_and_load(eps2,ctx)
                # Train with fake image
                # Use image pooling to utilize history images
                output = [netD(f) for f in fake_concat]
                output2 = [netD2(f) for f in fake_latent]
                fake_label = nd.zeros(out_i_shape)
                fake_label = gluon.utils.split_and_load(fake_label, ctx)
                fake_latent_label = nd.zeros(out_l_shape)
                fake_latent_label = gluon.utils.split_and_load(
                    fake_latent_label, ctx)
                eps = gluon.utils.split_and_load(
                    nd.random.uniform(low=-1, high=1, shape=latent_shape), ctx)
                rec_output = [netD(netDe(e)) for e in eps]
                errD_fake = [
                    GAN_loss(r, f) for r, f in zip(rec_output, fake_label)
                ]
                errD_fake2 = [
                    GAN_loss(o, f) for o, f in zip(output, fake_label)
                ]
                errD2_fake = [
                    GAN_loss(o, f) for o, f in zip(output2, fake_latent_label)
                ]
                for f, o in zip(fake_label, rec_output):
                    metric.update([
                        f,
                    ], [
                        o,
                    ])
                for f, o in zip(fake_latent_label, output2):
                    metric2.update([
                        f,
                    ], [
                        o,
                    ])
                real_concat = [nd.concat(r, o, dim=1)
                               for r, o in zip(real_in, real_out)] if append else real_out
                output = [netD(r) for r in real_concat]
                output2 = [netD2(r) for r in real_latent]
                real_label = gluon.utils.split_and_load(
                    nd.ones(out_i_shape), ctx)
                real_latent_label = gluon.utils.split_and_load(
                    nd.ones(out_l_shape), ctx)
                errD_real = [
                    GAN_loss(o, r) for o, r in zip(output, real_label)
                ]
                errD2_real = [
                    GAN_loss(o, r) for o, r in zip(output2, real_latent_label)
                ]
                for e1, e2, e4, e5 in zip(errD_real, errD_fake, errD2_real,
                                          errD2_fake):
                    err = (e1 + e2) * 0.5 + (e5 + e4) * 0.5
                    err.backward()
                for f, o in zip(real_label, output):
                    metric.update([
                        f,
                    ], [
                        o,
                    ])
                for f, o in zip(real_latent_label, output2):
                    metric2.update([
                        f,
                    ], [
                        o,
                    ])
            trainerD.step(batch.data[0].shape[0])
            trainerD2.step(batch.data[0].shape[0])
            nd.waitall()
            with autograd.record():
                strong_output = [netDS(netDe(e)) for e in eps]
                strong_real = [netDS(f) for f in fake_concat]
                errs1 = [
                    GAN_loss(r, f) for r, f in zip(strong_output, fake_label)
                ]
                errs2 = [
                    GAN_loss(r, f) for r, f in zip(strong_real, real_label)
                ]
                for f, s in zip(fake_label, strong_output):
                    metricStrong.update([
                        f,
                    ], [
                        s,
                    ])
                for f, s in zip(real_label, strong_real):
                    metricStrong.update([
                        f,
                    ], [
                        s,
                    ])
                for e1, e2 in zip(errs1, errs2):
                    strongerr = 0.5 * (e1 + e2)
                    strongerr.backward()
            trainerSD.step(batch.data[0].shape[0])
            nd.waitall()
            ############################
            # (2) Update G network: maximize log(D(x, G(x, z))) - lambda1 * L1(y, G(x, z))
            ###########################
            with autograd.record():
                sh = out_l_shape
                #eps2 = nd.random_normal(loc=0, scale=1, shape=noiseshape, ctx=ctx) #
                #eps = nd.random.uniform( low=-1, high=1, shape=noiseshape, ctx=ctx)
                #if epoch>100:
                #        eps2 = nd.multiply(eps2,sigma)+mu
                #        eps2 = nd.tanh(eps2)
                #else:
                #eps = nd.random.uniform( low=-1, high=1, shape=noiseshape, ctx=ctx)
                #eps2 = nd.concat(eps,eps2,dim=0)
                rec_output = [netD(netDe(e)) for e in eps2]
                fake_latent = [(netEn(r)) for r in real_in]
                output2 = [netD2(f) for f in fake_latent]
                fake_out = [netDe(f) for f in fake_latent]
                fake_concat = [nd.concat(r, f, dim=1)
                               for r, f in zip(real_in, fake_out)] if append else fake_out
                output = [netD(f) for f in fake_concat]
                real_label = gluon.utils.split_and_load(
                    nd.ones(out_i_shape), ctx)
                real_latent_label = gluon.utils.split_and_load(
                    nd.ones(out_l_shape), ctx)
                errG2 = [
                    GAN_loss(r, f) for r, f in zip(rec_output, real_label)
                ]
                errR = [
                    L1_loss(r, f) * lambda1
                    for r, f in zip(real_out, fake_out)
                ]
                errG = [
                    10 * GAN_loss(r, f)
                    for r, f in zip(output2, real_latent_label)
                ]  # +errG2+errR
                for e1, e2, e3 in zip(errG, errG2, errR):
                    e = e1 + e2 + e3
                    e.backward()
            trainerDe.step(batch.data[0].shape[0])
            trainerEn.step(batch.data[0].shape[0])
            nd.waitall()
            errD = (errD_real[0] + errD_fake[0]) * 0.5
            errD2 = (errD2_real[0] + errD2_fake[0]) * 0.5
            loss_rec_G2.append(nd.mean(errG2[0]).asscalar())
            loss_rec_G.append(
                nd.mean(nd.mean(errG[0])).asscalar() -
                nd.mean(errG2[0]).asscalar() - nd.mean(errR[0]).asscalar())
            loss_rec_D.append(nd.mean(errD[0]).asscalar())
            loss_rec_R.append(nd.mean(errR[0]).asscalar())
            loss_rec_D2.append(nd.mean(errD2[0]).asscalar())
            _, acc2 = metric2.get()
            name, acc = metric.get()
            acc_rec.append(acc)
            acc2_rec.append(acc2)

            # Print log information every ten batches
            if iter % 10 == 0:
                _, acc2 = metric2.get()
                name, acc = metric.get()
                _, accStrong = metricStrong.get()
                logging.info('speed: {} samples/s'.format(
                    batch_size / (time.time() - btic)))
                #print(errD)
                #logging.info('discriminator loss = %f, D2 loss = %f, generator loss = %f, G2 loss = %f, SD loss = %f,  D acc = %f , D2 acc = %f, DS acc = %f, reconstruction error= %f  at iter %d epoch %d'
                #   	% (nd.mean(errD[0]).asscalar(),nd.mean(errD2[0]).asscalar(),
                #     	nd.mean(errG[0]-errG2[0]-errR[0]).asscalar(),nd.mean(errG2[0]).asscalar(),nd.mean(strongerr[0]).asscalar() ,acc,acc2,accStrong[0],nd.mean(errR[0]).asscalar() ,iter, epoch))
            # Advance the batch counter and reset the per-batch timer on every
            # iteration, not only when a log line is emitted.
            iter = iter + 1
            btic = time.time()
        name, acc = metric.get()
        _, acc2 = metric2.get()
        #tp_file = open(expname + "_trainloss.txt", "a")
        #tp_file.write(str(nd.mean(errG2).asscalar()) + " " + str(
        #    nd.mean(nd.mean(errG)).asscalar() - nd.mean(errG2).asscalar() - nd.mean(errR).asscalar()) + " " + str(
        #    nd.mean(errD).asscalar()) + " " + str(nd.mean(errD2).asscalar()) + " " + str(nd.mean(errR).asscalar()) +" "+str(acc) + " " + str(acc2)+"\n")
        #tp_file.close()
        metric.reset()
        metric2.reset()
        train_data.reset()
        metricStrong.reset()

        logging.info('\nbinary training acc at epoch %d: %s=%f' %
                     (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))
        if epoch % 2 == 0:  # and epoch>0:
            text_file = open(expname + "_validtest.txt", "a")
            filename = "checkpoints/" + expname + "_" + str(
                epoch) + "_D.params"
            netD.save_parameters(filename)
            filename = "checkpoints/" + expname + "_" + str(
                epoch) + "_D2.params"
            netD2.save_parameters(filename)
            filename = "checkpoints/" + expname + "_" + str(
                epoch) + "_En.params"
            netEn.save_parameters(filename)
            filename = "checkpoints/" + expname + "_" + str(
                epoch) + "_De.params"
            netDe.save_parameters(filename)
            filename = "checkpoints/" + expname + "_" + str(
                epoch) + "_SD.params"
            netDS.save_parameters(filename)
            fake_img1 = nd.concat(real_in[0], real_out[0], fake_out[0], dim=1)
            fake_img2 = nd.concat(real_in[1], real_out[1], fake_out[1], dim=1)
            fake_img3 = nd.concat(real_in[2], real_out[2], fake_out[2], dim=1)
            fake_img4 = nd.concat(real_in[3], real_out[3], fake_out[3], dim=1)
            val_data.reset()
            text_file = open(expname + "_validtest.txt", "a")
            for vbatch in val_data:

                real_in = vbatch.data[0]
                real_out = vbatch.data[1]
                real_in = gluon.utils.split_and_load(real_in, ctx)
                real_out = gluon.utils.split_and_load(real_out, ctx)

                fake_latent = [netEn(r) for r in real_in]
                fake_out = [netDe(f) for f in fake_latent]
                for f, r in zip(fake_out, real_out):
                    metricMSE.update([
                        f,
                    ], [
                        r,
                    ])
            _, acc2 = metricMSE.get()
            toterrR = 0
            for e in errR:
                toterrR += nd.mean(e).asscalar()
            text_file.write("%s %s %s\n" % (str(epoch), toterrR, str(acc2)))
            metricMSE.reset()
    return ([
        loss_rec_D, loss_rec_G, loss_rec_R, acc_rec, loss_rec_D2, loss_rec_G2,
        acc2_rec
    ])
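
The metrics in the training loop above are built from a facc helper that is not defined in this snippet. A minimal sketch of a binary-accuracy function compatible with mx.metric.CustomMetric (an assumption; the original helper may differ) is:

def facc(label, pred):
    # Binary accuracy for discriminator outputs: flatten, threshold the
    # predictions at 0.5 and compare against the 0/1 labels.
    pred = pred.ravel()
    label = label.ravel()
    return ((pred > 0.5) == label).mean()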
Example #30
def train(opt):
    sw = SummaryWriter(logdir='./logs', flush_secs=5)
    decay_every = int(opt.n_epoch / 2)

    if opt.experiment is None:
        opt.experiment = 'samples'
    os.makedirs(opt.experiment, exist_ok=True)

    if opt.gpu_ids == '-1':
        context = [mx.cpu()]
    else:
        #context = mx.gpu(7)
        context = [mx.gpu(int(i)) for i in opt.gpu_ids.split(',') if i.strip()]
    print("context: {}".format(context))

    features = load_vgg_model_features(ctx_list=context, last_layer=28)

    ##### Prepare data for training or validation #####
    dataset = DataSet(
        opt.dataroot, RandomCrop(opt.fineSize),
        transforms.Resize(int(opt.fineSize / 4), interpolation=3),
        transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
    dataloader = DataLoader(dataset,
                            batch_size=opt.batchSize,
                            shuffle=True,
                            num_workers=int(opt.workers),
                            last_batch='rollover')

    ##### Build Network #####
    netG = SRGenerator()
    netG.initialize(ctx=context[0])
    netD = SRDiscriminator()
    netD.initialize(ctx=context[0])

    # Enforce non-deferred initialization by one forward pass computation
    dummy_in = nd.random.uniform(
        0,
        1, (1, 3, int(opt.fineSize / 4), int(opt.fineSize / 4)),
        ctx=context[0])
    netD(netG(dummy_in))

    # Re-initialize the parameters with our own scheme and place them on the training devices
    weights_init(netG.collect_params())
    netG.collect_params().reset_ctx(context)
    weights_init(netD.collect_params())
    netD.collect_params().reset_ctx(context)

    optimizer_G = gluon.Trainer(params=netG.collect_params(),
                                optimizer='adam',
                                optimizer_params={
                                    'learning_rate': opt.lr_init,
                                    'beta1': opt.beta1
                                },
                                kvstore='local')
    optimizer_D = gluon.Trainer(params=netD.collect_params(),
                                optimizer='adam',
                                optimizer_params={
                                    'learning_rate': opt.lr_init,
                                    'beta1': opt.beta1
                                },
                                kvstore='local')

    ##### Stage 1/2 of Training Process #####
    # Pre-train Generator G to avoid undesired local optima when training SRGAN.
    print("Start pre-train Generator ...")
    param_file = os.path.join(opt.experiment, 'netG_init_epoch.param')
    if os.path.exists(param_file):
        print(
            "Found existing pre-trained parameter file: {}, skipping the pre-train process."
            .format(param_file))
        netG.load_parameters(param_file, ctx=context)
    else:
        print("No existed parameter file, keep going to pre-train.")
        for epoch in range(opt.n_epoch_init):
            start = time.time()
            batch = 0
            for hr_img_iter, lr_img_iter in dataloader:
                #hr_img = hr_img.as_in_context(context)
                #lr_img = lr_img.as_in_context(context)
                hr_imgs = gluon.utils.split_and_load(hr_img_iter,
                                                     ctx_list=context)
                lr_imgs = gluon.utils.split_and_load(lr_img_iter,
                                                     ctx_list=context)

                with autograd.record():
                    ls = [
                        mse_loss(hr_img, netG(lr_img))
                        for hr_img, lr_img in zip(hr_imgs, lr_imgs)
                    ]
                for l in ls:
                    l.backward()

#                with autograd.record():
#                    hr_img_pred = netG(mx.nd.array(lr_img))
#                    loss = mse_loss(mx.nd.array(hr_img), hr_img_predit)
#                    autograd.backward(loss)
                optimizer_G.step(opt.batchSize)
                print("Epoch %d:  Batch %d:  mse: %.8f" %
                      (epoch, batch, ls[-1].mean().asscalar()))
                batch += opt.batchSize
            nd.waitall()
            train_time = time.time() - start
            print("Epoch %d:  mse: %.8f  trainning time:%.1f sec" %
                  (epoch, ls[-1].mean().asscalar(), train_time))
            if epoch % 20 == 0:
                netG.save_parameters('{0}/netG_init_epoch_{1}.param'.format(
                    opt.experiment, epoch))
            if epoch == opt.n_epoch_init - 1:
                netG.save_parameters('{0}/netG_init_epoch.param'.format(
                    opt.experiment))
    print("Pre-train Generator finished ...")

    ##### Stage 2/2 of Training Process #####
    # Jointly optimize G and D, namely train SRGAN.
    print("Start to train SRGAN ...")
    mean_mask = nd.zeros((opt.batchSize, 3, opt.fineSize, opt.fineSize),
                         ctx=context[0])
    mean_mask[:, 0, :, :] = 0.485
    mean_mask[:, 1, :, :] = 0.456
    mean_mask[:, 2, :, :] = 0.406
    std_mask = nd.zeros((opt.batchSize, 3, opt.fineSize, opt.fineSize),
                        ctx=context[0])
    std_mask[:, 0, :, :] = 0.229
    std_mask[:, 1, :, :] = 0.224
    std_mask[:, 2, :, :] = 0.225

    real_label = nd.ones((opt.batchSize, ), ctx=context[0])
    fake_label = nd.zeros((opt.batchSize, ), ctx=context[0])

    mean_masks = mx.gluon.utils.split_and_load(mean_mask, ctx_list=context)
    std_masks = mx.gluon.utils.split_and_load(std_mask, ctx_list=context)
    real_labels = mx.gluon.utils.split_and_load(real_label, ctx_list=context)
    fake_labels = mx.gluon.utils.split_and_load(fake_label, ctx_list=context)

    loss_d = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    losses_log = loss_dict()

    for epoch in range(0, opt.n_epoch):
        start = time.time()
        batch = 0
        train_errD = 0
        train_errG = 0
        for hr_img_iter, lr_img_iter in dataloader:
            losses_log.reset()

            hr_imgs = gluon.utils.split_and_load(hr_img_iter, ctx_list=context)
            lr_imgs = gluon.utils.split_and_load(lr_img_iter, ctx_list=context)
            hr_fake_imgs = []

            # Step1. Optimize D
            # Step2. Optimize G
            batch_errD = []
            batch_errG = []
            print("Optimize D in a Batch...")
            with autograd.record():
                for hr_img, lr_img, mean_mask, std_mask, real_label, fake_label in zip(
                        hr_imgs, lr_imgs, mean_masks, std_masks, real_labels,
                        fake_labels):
                    # errD computation
                    output = netD(hr_img).reshape((-1, 1))
                    errD_real = loss_d(output, real_label)
                    hr_img_fake = netG(lr_img)
                    hr_fake_imgs.append(hr_img_fake)
                    output = netD(hr_img_fake.detach()).reshape((-1, 1))
                    errD_fake = loss_d(output, fake_label)
                    errD = errD_real + errD_fake
                    batch_errD.append(errD)
                    losses_log.add(lr_img=lr_img,
                                   hr_img=hr_img,
                                   hr_img_fake=hr_img_fake)

                # run backward on batch errD and update parameters
                autograd.backward(batch_errD)
            optimizer_D.step(opt.batchSize)

            print("Optimize G in a Batch...")
            with autograd.record():
                for hr_img, lr_img, hr_img_fake, mean_mask, std_mask, real_label, fake_label in zip(
                        hr_imgs, lr_imgs, hr_fake_imgs, mean_masks, std_masks,
                        real_labels, fake_labels):
                    # errG computation
                    errM = mse_loss(hr_img_fake, hr_img)
                    input_fake = ((hr_img_fake + 1) / 2 - mean_mask) / std_mask
                    fake_emb = vgg_feature(input_fake, features)
                    input_real = ((hr_img + 1) / 2 - mean_mask) / std_mask
                    real_emb = vgg_feature(input_real, features)
                    errV = 6e-3 * mse_loss(fake_emb, real_emb)
                    output = netD(hr_img_fake).reshape((-1, 1))
                    errA = 1e-3 * loss_d(output, real_label)
                    errG = errM + errV + errA
                    batch_errG.append(errG)

            # run backward on batch errG and update parameters
            autograd.backward(batch_errG)
            #            for errG in batch_errG:
            #                errG.backward()
            #                losses_log.add(errG=errG, errM=errM, errV=errV, errA=errA)
            optimizer_G.step(opt.batchSize)

            # sum losses over all devices
            train_errD += sum([errD.sum().asscalar() for errD in batch_errD])
            train_errG += sum([errG.sum().asscalar() for errG in batch_errG])
            print(
                "Epoch:%d, Batch:%d  -----   D-Loss = %.3f, G-Loss = %.3f    (Time %.1f sec)"
                % (epoch, batch * opt.batchSize, train_errD, train_errG,
                   time.time() - start))
            batch += 1

            plot_loss(sw, losses_log,
                      epoch * len(dataloader) + batch, epoch, batch)

        if epoch != 0 and (epoch % decay_every == 0):
            optimizer_G.set_learning_rate(optimizer_G.learning_rate *
                                          opt.lr_decay)
            optimizer_D.set_learning_rate(optimizer_D.learning_rate *
                                          opt.lr_decay)
        if (epoch != 0) and (epoch % 10 == 0):
            plot_img(sw, losses_log)
            netG.save_parameters('{0}/netG_epoch_{1}.param'.format(
                opt.experiment, epoch))
            netD.save_parameters('{0}/netD_epoch_{1}.param'.format(
                opt.experiment, epoch))
    print("Train SRGAN finished ...")