def nd_forward_backward_and_profile(op, runs, *args, **kwargs):
    """Helper function to run a given NDArray operator (op) for 'runs' number of times
    with given args and kwargs. Executes both forward and backward pass.

    NOTE: This is a sync call and waits for all the operations execution to complete.

    Parameters
    ----------
    op: function
        NDArray operator (function reference) to execute. Example: mx.nd.add
    runs: int
        Number of times to execute the operation
    args:
        Positional arguments for the NDArray operator (op) being executed.
    kwargs:
        Keyword arguments for the NDArray operator (op) being executed.

    Returns
    -------
    any results from NDArray operation execution
    """
    for _ in range(runs):
        with mx.autograd.record():
            # guard against empty args before inspecting args[0]
            if args and isinstance(args[0], nd.NDArray):
                res = op(*args, **kwargs)
            else:
                res = op(**kwargs)
        res.backward()
        nd.waitall()
    return res
def nd_forward_and_profile(op, runs, *args, **kwargs):
    """Helper function to run a given NDArray operator (op) for 'runs' number of times
    with given args and kwargs. Executes ONLY forward pass.

    NOTE: This is a sync call and waits for all the operations execution to complete.

    Parameters
    ----------
    op: function
        NDArray operator (function reference) to execute. Example: mx.nd.add
    runs: int
        Number of times to execute the operation
    args:
        Positional arguments for the NDArray operator (op) being executed.
    kwargs:
        Keyword arguments for the NDArray operator (op) being executed.

    Returns
    -------
    any results from NDArray operation execution
    """
    for _ in range(runs):
        res = op(*args, **kwargs)
    nd.waitall()
    return res
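# Usage sketch for the two profiling helpers above. The operator (mx.nd.add),
# shapes, and run count are illustrative assumptions, not taken from the
# original code:
import mxnet as mx
from mxnet import nd

a = nd.ones((1024, 1024))
b = nd.ones((1024, 1024))
a.attach_grad()  # backward() needs at least one input with a gradient buffer
res = nd_forward_backward_and_profile(mx.nd.add, 10, a, b)
res = nd_forward_and_profile(mx.nd.add, 10, a, b)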
def predict():
    img = pre_process_image(
        '/incubator-mxnet/scala-package/examples/scripts/infer/images/dog.jpg')
    # compute the prediction probabilities and time the predict call
    import time
    data_iter = mx.io.NDArrayIter([img], None, 16)
    start = time.time()
    op = mod.predict(data_iter)  # mod.forward(Batch([img]))
    nd.waitall()
    end = time.time()
    print(end - start)
    # To print the top-5 predictions:
    # prob = np.squeeze(op[0].asnumpy())
    # a = np.argsort(prob)[::-1]
    # for i in a[0:5]:
    #     print('probability=%f, class=%s' % (prob[i], labels[i]))
    return end - start
def test_nd_op(ctx=mx.gpu()):
    nd.waitall()
    np_arr = np.zeros((3, 1, 5, 5))
    nd_arr = nd.array(np_arr, ctx)
    print(nd_arr)
    time.sleep(5)
    print('success')
def workflow_inference(self, instream, shape):
    while True:
        # Capture frame-by-frame
        ret, frame = cap.read()
        frame = cv2.resize(frame, (1280, 720))
        # Our operations on the frame come here
        # gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        out = self._retina_forward(frame)
        try:
            self.write_queue((frame, out))
        except Exception:
            waitall()
            print('Frame queue full', file=sys.stderr)
        # Display the resulting frame
        # cv2.imshow('frame', gray)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    # When everything is done, release the capture
    cap.release()
    cv2.destroyAllWindows()
def _impl(op, params, graph, **kwargs):
    deps = kwargs['deps']
    name, op_name = op.attr('name'), op.attr('op_name')
    childs, attr = sym_iter(op.get_children()), op.list_attr()
    if op_name == 'null':
        start_time = None
        out = data if is_inputs(op, params) else params[name]
    elif childs is None:
        start_time = time.time()
        out = get_nd_op(op_name)(**attr)
        if gpu_flag:
            nd.waitall()
        end_time = time.time()
    else:
        cinfos = [(c.attr('name'), get_entry_id(c)) for c in childs]
        nd_inputs = [out_cache[n[0]][n[1]] for n in cinfos]
        start_time = time.time()
        out = get_nd_op(op_name)(*nd_inputs, **attr)
        if gpu_flag:
            nd.waitall()
        end_time = time.time()
        for n, _ in cinfos:
            assert n in deps
            deps[n].remove(name)
            if len(deps[n]) == 0:
                del out_cache[n]
    if start_time is not None:
        if op_name not in times:
            times[op_name] = {}
        times[op_name][name] = end_time - start_time
    out = [out] if len(op) == 1 else out
    out_cache[name] = [o.as_in_context(ctx) for o in out]
def hybrid_forward(self, F, x, *args, **kwargs):
    # NOTE: nd.waitall() after each layer forces a sync so that time0 measures
    # that layer's latency; this only works in imperative (non-hybridized) mode.
    for layer in self.net:
        tic = time.time()
        x = layer(x)
        nd.waitall()
        time0 = time.time() - tic
    return x
def test_model(model, inpt, amount, wait=True):
    tic = time.time()
    for i in range(amount):
        if amount > 1 and i == 1:
            # restart the clock after the first (warm-up) run
            tic = time.time()
        out = model(inpt)
        if wait:
            nd.waitall()
    # exclude the warm-up run from the average
    amount = amount - 1 if amount > 1 else amount
    time_use = (time.time() - tic) / amount
    return time_use, out
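# Usage sketch for test_model above; the block and input shape are illustrative
# assumptions (any callable Gluon block works):
import mxnet as mx
from mxnet import nd, gluon

net = gluon.nn.Dense(10)
net.initialize()
x = nd.ones((32, 64))
avg_s, out = test_model(net, x, amount=20)  # first run is treated as warm-up
print('avg forward time: %.3f ms' % (avg_s * 1000))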
def do_a_train_step(self, stp=None):
    # NOTE: wait for the weight updates to finish
    ndarray.waitall()
    if not stp:
        stp = self.batch_size
    self.trainer.step(stp)
    # TODO: zero grad
    mxprms.params_zero_grad(self.net.collect_params())
def workflow_inference(self, instream, shape):
    for source in instream:
        # st = time.perf_counter()
        frame = frombuffer(source, dtype=uint8).reshape(shape)
        out = self._retina_forward(frame)
        try:
            self.write_queue((frame, out))
        except Exception:
            waitall()
            print('Frame queue full', file=sys.stderr)
def trainNet(net, trainer, train_data, loss, train_metric, epoch, config, logger, ctx):
    if not logger:
        assert False, 'require a logger'
    train_data.reset()  # reset and re-shuffle
    if train_metric:
        train_metric.reset()
    trainloss, n = [0] * len(ctx), 0
    for batch_i, batch in enumerate(train_data):
        data_list = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
        label_list = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        Ls = []
        output_list = []
        with autograd.record():
            for x, y in zip(data_list, label_list):
                preds = net(x)
                L = loss(preds, y)
                Ls.append(L)
                output_list.append(preds)
        for L in Ls:
            L.backward()
        trainer.step(batch.data[0].shape[0])
        # number of samples seen
        n += batch.data[0].shape[0]
        # accumulate the per-device loss
        for i in range(len(trainloss)):
            trainloss[i] += Ls[i]
    nd.waitall()
    trainloss = sum([item.sum().asscalar() for item in trainloss])
    logger.info("TRAIN - Epoch:%d LR:%.2e Loss:%.2e" %
                (epoch + 1, trainer.learning_rate, trainloss / n))
    # save model
    if ((epoch + 1) % (config.TRAIN.end_epoch / 5) == 0 or epoch == 0):
        saveModel(net, logger, config, isCKP=True, epoch=epoch + 1)
    if (epoch + 1 == config.TRAIN.end_epoch):
        saveModel(net, logger, config, isCKP=False, epoch=epoch + 1)
def nd_forward_and_time(F, runs, *args, **kwargs):
    """Helper function to run a given NDArray operator (F) for 'runs' number of times
    with given args and kwargs. Executes ONLY forward pass.

    NOTE: This is a sync call and waits for all the operations execution to complete.

    :param F: NDArray operator (function reference) to execute. Example: mx.nd.add
    :param runs: Number of times to execute the operation
    :param args: Arguments for the NDArray operator (F) being executed.
    :param kwargs: Key value arguments for the NDArray operator (F) being executed.
    :return: None. The caller is expected to time this call; it returns only after
        all queued operations have completed.
    """
    for _ in range(runs):
        F(*args, **kwargs)
    nd.waitall()
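# Timing sketch around nd_forward_and_time above; the helper itself does not
# return a time, so the caller wraps it. Operator and shapes are illustrative
# assumptions:
import time
import mxnet as mx
from mxnet import nd

x = nd.ones((1024, 1024))
y = nd.ones((1024, 1024))
nd_forward_and_time(mx.nd.add, 1, x, y)  # warm-up run
start = time.time()
nd_forward_and_time(mx.nd.add, 100, x, y)
print('avg per run: %.3f ms' % ((time.time() - start) / 100 * 1000))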
def block_forward_backward_and_time(*args, block, runs, **kwargs):
    """Helper function to run a given Block (block) for 'runs' number of times
    with given args and kwargs. Executes both forward and backward pass.

    NOTE: This is a sync call and waits for all the operations execution to complete.

    :param block: Gluon block to execute. Example: an instance of gluon.nn.Dense(...)
    :param runs: Number of times to execute the block operation
    :param args: Arguments for the block being executed.
    :param kwargs: Key value arguments for the block being executed.
    :return: None. The caller is expected to time this call.
    """
    for _ in range(runs):
        with mx.autograd.record():
            res = block.forward(*args, **kwargs)
        res.backward()
    nd.waitall()
def _run_forward_backward_benchmark(self, runs, x):
    for _ in range(runs):
        with mx.autograd.record():
            # Forward
            res1 = x + 1
            res2 = res1 + 1
            res3 = res2 + 1
            res4 = nd.Custom(res3, name="customaddone", op_type="CustomAddOne")
            res5 = res4 + 1
            res6 = res5 + 1
            res7 = res6 + 1
        # Backward
        res7.backward()
    nd.waitall()
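# A minimal sketch of a "CustomAddOne" operator as assumed by the benchmark
# above, written against MXNet's mx.operator custom-op API; the real definition
# used with that benchmark may differ:
import mxnet as mx

class CustomAddOne(mx.operator.CustomOp):
    def forward(self, is_train, req, in_data, out_data, aux):
        # out = in + 1
        self.assign(out_data[0], req[0], in_data[0] + 1)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        # d(out)/d(in) = 1, so the gradient passes through unchanged
        self.assign(in_grad[0], req[0], out_grad[0])

@mx.operator.register("CustomAddOne")
class CustomAddOneProp(mx.operator.CustomOpProp):
    def __init__(self):
        super(CustomAddOneProp, self).__init__(need_top_grad=True)

    def list_arguments(self):
        return ['data']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        # output has the same shape as the input; no aux states
        return [in_shape[0]], [in_shape[0]], []

    def create_operator(self, ctx, shapes, dtypes):
        return CustomAddOne()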
def block_forward_and_time(*args, block, runs, **kwargs):
    """Helper function to run a given Block (block) for 'runs' number of times
    with given args and kwargs. Executes forward pass only.

    NOTE: This is a sync call and waits for all the operations execution to complete.

    :param block: Gluon block to execute. Example: an instance of gluon.nn.Dense(...)
    :param runs: Number of times to execute the block operation
    :param args: Arguments for the block being executed.
    :param kwargs: Key value arguments for the block being executed.
    :return: None. The caller is expected to time this call.
    """
    for _ in range(runs):
        # Imperative mode: call the block's forward computation directly.
        # F is passed positionally so that positional inputs in *args do not
        # collide with the F parameter.
        block.hybrid_forward(nd, *args, **kwargs)
    nd.waitall()
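# Usage sketch for block_forward_backward_and_time and block_forward_and_time
# above; the block and shapes are illustrative assumptions. Note that block and
# runs must be passed by keyword, and calling hybrid_forward directly (as
# block_forward_and_time does) requires passing the block's parameter arrays
# explicitly; in_units is fixed so the parameters exist before the first call:
import mxnet as mx
from mxnet import nd, gluon

dense = gluon.nn.Dense(64, in_units=128)
dense.initialize()
x = nd.ones((32, 128))
block_forward_backward_and_time(x, block=dense, runs=10)
block_forward_and_time(x, block=dense, runs=10,
                       weight=dense.weight.data(), bias=dense.bias.data())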
def nd_forward_backward_and_profile(op, runs, **kwargs):
    """Helper function to run a given NDArray operator (op) for 'runs' number of times
    with given kwargs. Executes both forward and backward pass.

    NOTE: This is a sync call and waits for all the operations execution to complete.

    Parameters
    ----------
    op: function
        NDArray operator (function reference) to execute. Example: mx.nd.add
    runs: int
        Number of times to execute the operation
    kwargs:
        Key value arguments for the NDArray operator (op) being executed.
        Keys starting with "args" are treated as positional arguments.

    Returns
    -------
    any results from NDArray operation execution
    """
    for _ in range(runs):
        with mx.autograd.record():
            args = []
            # need to create a new dictionary because we can't update a dict
            # while iterating over it
            kwargs_new = dict()
            for key in kwargs:
                # separate positional args from keyword args
                if key.startswith("args"):
                    args.append(kwargs[key])
                else:
                    kwargs_new[key] = kwargs[key]
            # check for positional args
            if len(args):
                res = op(*args, **kwargs_new)
            else:
                res = op(**kwargs_new)
        res.backward()
        nd.waitall()
    return res
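# Usage sketch for the kwargs-packed convention in this variant of
# nd_forward_backward_and_profile: keys starting with "args" are peeled off as
# positional arguments (dicts preserve insertion order in Python 3.7+).
# Operator and shapes are illustrative assumptions:
import mxnet as mx
from mxnet import nd

lhs = nd.ones((256, 256))
rhs = nd.ones((256, 256))
lhs.attach_grad()  # backward() needs an input with a gradient buffer
res = nd_forward_backward_and_profile(mx.nd.add, 5, args0=lhs, args1=rhs)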
def validNet(net, valid_data, loss, eval_metric, epoch, config, logger, ctx):
    if not logger:
        assert False, 'require a logger'
    valid_data.reset()
    if eval_metric:
        eval_metric.reset()
    validloss, n = [0] * len(ctx), 0
    for batch_i, batch in enumerate(valid_data):
        data_list = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
        label_list = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        Ls = []
        output_list = []
        for x, y in zip(data_list, label_list):
            preds = net(x)
            L = loss(preds, y)
            output_list.append(preds)
            Ls.append(L)
        if config.TRAIN.UseMetric:
            for lb, pd in zip(label_list, output_list):
                eval_metric.update(lb, pd)
        for i in range(len(validloss)):
            validloss[i] += Ls[i]
        n += batch.data[0].shape[0]
    nd.waitall()
    validloss = sum([item.sum().asscalar() for item in validloss])
    MPJPE = eval_metric.get()[-1].sum(axis=0) / 17
    logger.info("VALID - Epoch:%d Loss:%.3e MPJPE:%.1f" %
                (epoch + 1, validloss / n, MPJPE))
def _chris_update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names):
    """Perform update of param_arrays from grad_arrays on kvstore."""
    for index, pair in enumerate(zip(param_arrays, grad_arrays)):
        arg_list, grad_list = pair
        if grad_list[0] is None:
            continue
        name = param_names[index]
        # push gradient, priority is negative index
        kvstore.push(name, grad_list, priority=-index)
    # os.getenv returns a string, so cast before comparing
    if int(os.getenv("GLOBAL_BARRIER", 0)) == 1:
        ndarray.waitall()
    # if os.getenv('PULL_SLEEP_TIME') is not None:
    #     delay = float(os.getenv('PULL_SLEEP_TIME'))
    #     time.sleep(delay)
    # self.logger.info("before pull in _chris_update_params_on_kvstore, time is:", time.time())
    for index, pair in enumerate(zip(param_arrays, grad_arrays)):
        arg_list, grad_list = pair
        if grad_list[0] is None:
            continue
        name = param_names[index]
        # pull back the weights
        kvstore.pull(name, arg_list, priority=-index)
def train(args, model, train_sampler, valid_samplers=None, rank=0,
          rel_parts=None, barrier=None):
    assert args.num_proc <= 1, "MXNet KGE does not support multi-process now"
    assert args.rel_part == False, \
        "No need for relation partition in single process for MXNet KGE"
    logs = []

    for arg in vars(args):
        logging.info("{:20}:{}".format(arg, getattr(args, arg)))

    if len(args.gpu) > 0:
        gpu_id = (args.gpu[rank % len(args.gpu)]
                  if args.mix_cpu_gpu and args.num_proc > 1 else args.gpu[0])
    else:
        gpu_id = -1

    if args.strict_rel_part:
        model.prepare_relation(mx.gpu(gpu_id))

    if mxprofiler:
        from mxnet import profiler
        profiler.set_config(
            profile_all=True,
            aggregate_stats=True,
            continuous_dump=True,
            filename="profile_output.json",
        )
    start = time.time()
    for step in range(0, args.max_step):
        pos_g, neg_g = next(train_sampler)
        args.step = step
        if step == 1 and mxprofiler:
            profiler.set_state("run")
        with mx.autograd.record():
            loss, log = model.forward(pos_g, neg_g, gpu_id)
        loss.backward()
        logs.append(log)
        model.update(gpu_id)

        if step % args.log_interval == 0:
            for k in logs[0].keys():
                v = sum(l[k] for l in logs) / len(logs)
                print("[Train]({}/{}) average {}: {}".format(step, args.max_step, k, v))
            logs = []
            print(time.time() - start)
            start = time.time()

        if (args.valid and step % args.eval_interval == 0 and step > 1
                and valid_samplers is not None):
            start = time.time()
            test(args, model, valid_samplers, mode="Valid")
            print("test:", time.time() - start)

    if args.strict_rel_part:
        model.writeback_relation(rank, rel_parts)
    if mxprofiler:
        nd.waitall()
        profiler.set_state("stop")
        profiler.dump()
        print(profiler.dumps())
    # clear cache
    logs = []
def validNet(net, valid_data, loss, eval_metric, epoch, config, logger, ctx):
    if not logger:
        assert False, 'require a logger'
    valid_data.reset()
    if eval_metric:
        eval_metric.reset()
    w = config.TRAIN.w
    batchsize = config.TRAIN.batchsize
    UseMetric = config.TRAIN.UseMetric
    seqLength = config.DATASET.seqLength
    nJoints = config.NETWORK.nJoints
    loss1, loss2, n1, n2 = [0] * len(ctx), [0] * len(ctx), 0.000001, 0.000001
    for batch_i, batch in enumerate(valid_data):
        data_list = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=1)
        label_list = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=1)
        Ls1, Ls2, output_list = [], [], []
        # forward
        for data, label, cx in zip(data_list, label_list, ctx):
            initial_state = [
                nd.zeros(shape=(batchsize, config.NETWORK.hidden_dim), ctx=cx)
                for _ in range(2)
            ]
            start_token = nd.ones(shape=(batchsize, 3 * nJoints), ctx=cx)
            preds = net(data, initial_state, start_token)
            output_list.append(preds)  # pred = [seqLength, 64x48]
            L1, L2 = 0, 0
            for pd, lb in zip(preds, label):
                L1 = L1 + loss(pd, lb)
            if seqLength > 1:
                for i in range(1, seqLength):
                    deltaP = preds[i] - preds[i - 1]
                    deltaG = label[i] - label[i - 1]
                    L2 = L2 + loss(deltaP, deltaG)
            Ls1.append(L1)
            Ls2.append(L2 if seqLength > 1 else nd.zeros(1))
        # number of samples
        n1 = n1 + len(ctx) * batchsize * seqLength
        n2 = n2 + len(ctx) * batchsize * (seqLength - 1)
        # loss
        for i in range(len(loss1)):
            loss1[i] += Ls1[i]
            loss2[i] += Ls2[i]
        # metric (this saves time)
        if UseMetric:
            for pred_batch, label_batch in zip(output_list, label_list):
                # for each timestamp
                for t_pred, t_label in zip(pred_batch, label_batch):
                    eval_metric.update(t_label, t_pred)
    nd.waitall()
    loss1 = sum([item.sum().asscalar() for item in loss1])
    loss2 = sum([item.sum().asscalar() for item in loss2])
    validloss = loss1 / n1 + w * loss2 / n2
    MPJPE = eval_metric.get()[-1].sum(axis=0) / 17 if UseMetric else 0
    logger.info(
        "VALID - Epoch:%2d Loss1:%.2e Loss2(%2d):%.2e TotalLoss:%.2e MPJPE:%.1f"
        % (epoch + 1, loss1 / n1, w, loss2 / n2, validloss, MPJPE))
import mxnet as mx
import mxnet.ndarray as nd
import time

N = 10
a = mx.random.normal(0, 1, (4, 1024, 1024), ctx=mx.gpu(0))
b = mx.nd.empty((4, 1024, 1024), ctx=mx.cpu())
c = mx.nd.empty((4, 1024, 1024), ctx=mx.gpu())
nd.waitall()
a.asnumpy()

# float32 arrays: 4 bytes per element, hence the factor of 4 below
start = time.time()
for i in range(N):
    b[:] = a
    nd.waitall()
avg_time = (time.time() - start) / N
print('GPU->CPU, GB/s: %g' % (4 * a.size * 1E-9 / avg_time))

start = time.time()
for i in range(N):
    a[:] = b
    nd.waitall()
avg_time = (time.time() - start) / N
print('CPU->GPU, GB/s: %g' % (4 * a.size * 1E-9 / avg_time))

start = time.time()
for i in range(N):
    c[:] = a
    nd.waitall()
avg_time = (time.time() - start) / N
print('GPU->GPU, GB/s: %g' % (4 * a.size * 1E-9 / avg_time))
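# A small sketch of why the nd.waitall() calls above matter: MXNet executes
# asynchronously, so without a sync the timer only measures how long it takes
# to *queue* the work, not to finish it. Shapes are illustrative:
import time
import mxnet as mx
from mxnet import nd

x = nd.ones((2048, 2048))
t0 = time.time()
y = nd.dot(x, x)              # returns immediately; the work is only queued
t_submit = time.time() - t0
nd.waitall()                  # block until all queued work has completed
t_complete = time.time() - t0
print('submit: %.6f s, complete: %.6f s' % (t_submit, t_complete))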
def fit(self, train_data, eval_data=None, eval_metric='acc',
        epoch_end_callback=None, batch_end_callback=None, kvstore='local',
        optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
        eval_end_callback=None, eval_batch_end_callback=None,
        initializer=Uniform(0.01), arg_params=None, aux_params=None,
        allow_missing=False, force_rebind=False, force_init=False,
        begin_epoch=0, num_epoch=None, validation_metric=None, monitor=None):
    assert num_epoch is not None, 'please specify number of epochs'

    self.bind(data_shapes=train_data.provide_data,
              label_shapes=train_data.provide_label,
              for_training=True, force_rebind=force_rebind)
    if monitor is not None:
        self.install_monitor(monitor)
    self.init_params(initializer=initializer, arg_params=arg_params,
                     aux_params=aux_params, allow_missing=allow_missing,
                     force_init=force_init)
    self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                        optimizer_params=optimizer_params)

    if validation_metric is None:
        validation_metric = eval_metric
    if not isinstance(eval_metric, metric.EvalMetric):
        eval_metric = metric.create(eval_metric)

    #### chris_arg
    # TASK_LIMIT: 0 = no per-task bandwidth limit; 1 = per-task limit, updated
    # every round; 2 = per-task limit, but fixed
    if int(os.getenv("TASK_LIMIT", 0)) != 0:
        get_task_cmd = "sh /home/ubuntu/tc.sh -l 1"
    else:
        self.logger.info("no_task_bandwidth_limit")
        get_task_cmd = "sh /home/ubuntu/tc.sh -l 0"
    os.system(get_task_cmd)
    delay_time = float(os.getenv("DELAY_TIME", 0.8))
    ps_upload_bandwidth_part1 = int(os.getenv("PS_UPLOAD_BANDWIDTH1", 2000))
    worker_upload_bandwidth_part1 = int(os.getenv("WORKER_UPLOAD_BANDWIDTH1", 2000))
    ps_upload_bandwidth_part2 = int(os.getenv("PS_UPLOAD_BANDWIDTH2", 2000))
    worker_upload_bandwidth_part2 = int(os.getenv("WORKER_UPLOAD_BANDWIDTH2", 2000))
    tc_command = ("sudo tc class change dev {} parent 1: classid 1:3 htb rate {}mbit ceil {}mbit"
                  " && sudo tc class change dev {} parent 1: classid 1:4 htb rate {}mbit ceil {}mbit")

    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        data_iter = iter(train_data)
        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward(data_batch, is_train=True)
            if int(os.getenv("TASK_LIMIT", 0)) == 1:
                # first part of the bandwidth allocation
                ndarray.waitall()
                # self.logger.info("change bandwidth part1:, " + str(time.time()))
                x = str(ps_upload_bandwidth_part1)
                y = str(worker_upload_bandwidth_part1)
                cmd_up = tc_command.format("ens3", x, x, "ens3", y, y)
                cmd_down = tc_command.format("ifb0", y, y, "ifb0", x, x)
                os.system(cmd_up)
                # os.system(cmd_down)
            # self.logger.info("after forward, " + str(time.time()))
            self.backward()
            # self.logger.info("before update: " + str(time.time()))
            self.update()  # executed asynchronously
            if int(os.getenv("TASK_LIMIT", 0)) == 1:
                x = str(ps_upload_bandwidth_part2)
                y = str(worker_upload_bandwidth_part2)
                cmd_up = tc_command.format("ens3", x, x, "ens3", y, y)
                cmd_down = tc_command.format("ifb0", y, y, "ifb0", x, x)
                time.sleep(delay_time)
                # second part of the bandwidth allocation
                # self.logger.info("change bandwidth part2:, " + str(time.time()))
                os.system(cmd_up)
                # os.system(cmd_down)
            try:
                # pre-fetch next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch)
            except StopIteration:
                end_of_batch = True
            self.update_metric(eval_metric, data_batch.label)
            if monitor is not None:
                monitor.toc_print()
            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        # one epoch of training is finished
        for name, val in eval_metric.get_name_value():
            self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
        toc = time.time()
        self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        # ----------------------------------------
        # evaluation on validation set
        if eval_data:
            res = self.score(eval_data, validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            # TODO: pull this into default
            for name, val in res:
                self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)

        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()
def trainNet(net, trainer, train_data, loss, train_metric, epoch, config, logger, ctx):
    if not logger:
        assert False, 'require a logger'
    train_data.reset()  # reset and re-shuffle
    if train_metric:
        train_metric.reset()
    trainer.set_learning_rate(config.TRAIN.lr * pow(config.TRAIN.decay_rate, epoch))
    w = config.TRAIN.w
    batchsize = config.TRAIN.batchsize
    UseMetric = config.TRAIN.UseMetric
    seqLength = config.DATASET.seqLength
    nJoints = config.NETWORK.nJoints
    loss1, loss2, n1, n2 = [0] * len(ctx), [0] * len(ctx), 0.000001, 0.000001
    RecordTime = {'load': 0, 'forward': 0, 'backward': 0, 'post': 0}
    for batch_i, batch in enumerate(train_data):
        beginT = time.time()
        data_list = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=1)
        label_list = gluon.utils.split_and_load(
            batch.label[0], ctx_list=ctx, batch_axis=1)  # [[seqLength x 64 x 48], 4]
        RecordTime['load'] += time.time() - beginT
        # forward
        beginT = time.time()
        Ls, Ls1, Ls2, output_list = [], [], [], []
        with autograd.record():
            for data, label, cx in zip(data_list, label_list, ctx):
                initial_state = [
                    nd.zeros(shape=(batchsize, config.NETWORK.hidden_dim), ctx=cx)
                    for _ in range(2)
                ]
                start_token = nd.ones(shape=(batchsize, 3 * nJoints), ctx=cx)
                preds = net(data, initial_state, start_token)
                output_list.append(preds)  # pred = [5, 64x48]
                L1, L2 = 0, 0
                for pd, lb in zip(preds, label):
                    L1 = L1 + loss(pd, lb)
                if seqLength > 1:
                    for i in range(1, seqLength):
                        deltaP = preds[i] - preds[i - 1]
                        deltaG = label[i] - label[i - 1]
                        L2 = L2 + loss(deltaP, deltaG)
                Ls1.append(L1)
                Ls2.append(L2 if seqLength > 1 else nd.zeros(1))
                Ls.append(L1 + w * L2)
        RecordTime['forward'] += time.time() - beginT
        # backward
        beginT = time.time()
        for L in Ls:
            L.backward()
        trainer.step(len(ctx) * batchsize)
        RecordTime['backward'] += time.time() - beginT
        beginT = time.time()
        # number of samples
        n1 = n1 + len(ctx) * batchsize * seqLength
        n2 = n2 + len(ctx) * batchsize * (seqLength - 1)
        # loss
        for i in range(len(loss1)):
            loss1[i] += Ls1[i]
            loss2[i] += Ls2[i]
        # metric (this saves time)
        if UseMetric:
            for pred_batch, label_batch in zip(output_list, label_list):
                # for each timestamp
                for t_pred, t_label in zip(pred_batch, label_batch):
                    train_metric.update(t_label, t_pred)
        RecordTime['post'] += time.time() - beginT
    totalT = nd.array([RecordTime[k] for k in RecordTime]).sum().asscalar()
    for key in RecordTime:
        print("%-s: %.1fs %.1f%% " %
              (key, RecordTime[key], RecordTime[key] / totalT * 100), end=" ")
    print(" ")
    nd.waitall()
    loss1 = sum([item.sum().asscalar() for item in loss1])
    loss2 = sum([item.sum().asscalar() for item in loss2])
    TotalLoss = loss1 / n1 + w * loss2 / n2
    MPJPE = train_metric.get()[-1].sum(axis=0).asscalar() / 17 if UseMetric else 0
    logger.info(
        "TRAIN - Epoch:%2d LR:%.2e Loss1:%.2e Loss2(%2d):%.2e TotalLoss:%.2e MPJPE:%.1f"
        % (epoch + 1, trainer.learning_rate, loss1 / n1, w, loss2 / n2, TotalLoss, MPJPE))
    if ((epoch + 1) % (config.end_epoch / 4) == 0 or epoch == 0):
        # save checkpoint
        saveModel(net, logger, config, isCKP=True, epoch=epoch + 1)
    if (epoch + 1 == config.end_epoch):
        # save final model
        saveModel(net, logger, config, isCKP=False, epoch=epoch + 1)
def test1():
    ctx = mx.gpu()
    mx.random.seed(128)
    Cout = 256
    kernel_max = 3
    N, Cin, Height, Width = (128, 256, 112, 112)
    # mask = get_random_mask(Cout, Cin, (3, 3), kernel_max).as_in_context(ctx)
    mask = nd.array([[[0, 1, 2]] * Cin] * Cout, dtype='float32', ctx=ctx)
    mask_ = nd.array([[0, 1, 2]] * Cin, dtype='float32', ctx=ctx)
    weight = nd.random.normal(0, 1e-2, (Cout, Cin, kernel_max))
    amount = 10
    inpt = nd.random.uniform(0, 1, (N, Cin, Height, Width), ctx=ctx)

    # custom original convolution
    net0 = MyConv(Cout, Cin, (1, kernel_max), kernel_max, mask,
                  weight_initializer=mx.init.Constant(weight.expand_dims(2)),
                  padding=(0, 1))
    net0.initialize(ctx=ctx)
    tic = time.time()
    for _ in range(amount):
        out0 = net0(inpt)
        nd.waitall()
    # out0 = net0(inpt)
    timeuse0 = (time.time() - tic) / amount
    # out0 = out0[:, :, :-2, 1:-1]

    # FLK v1
    # net = FLKConv(Cout, Cin, (kernel_max, kernel_max), kernel_max, mask,
    #               weight_initializer=mx.init.Constant(weight))
    # net.initialize(ctx=ctx)
    # tic = time.time()
    # for _ in range(amount):
    #     out1 = net(inpt)
    #     nd.waitall()
    # timeuse1 = (time.time() - tic) / amount

    # net2 = FLKConv_v2(Cout, Cin, (kernel_max, kernel_max), kernel_max, mask,
    #                   weight_initializer=mx.init.Constant(weight))
    # net2.initialize(ctx=ctx)
    # tic = time.time()
    # for _ in range(amount):
    #     out2 = net2(inpt)
    #     nd.waitall()
    # timeuse2 = (time.time() - tic) / amount

    # net3 = FLKConv_v3(Cout, Cin, (kernel_max, kernel_max), kernel_max, mask,
    #                   weight_initializer=mx.init.Constant(weight))
    # net3.initialize(ctx=ctx)
    # tic = time.time()
    # for _ in range(amount):
    #     out3 = net3(inpt)
    #     nd.waitall()
    # timeuse3 = (time.time() - tic) / amount

    net4 = FLKConv_v4(Cout, Cin, (kernel_max, kernel_max), kernel_max, mask_,
                      weight_initializer=mx.init.Constant(weight))
    net4.initialize(ctx=ctx)
    tic = time.time()
    for _ in range(amount):
        out4 = net4(inpt)
        nd.waitall()
    timeuse4 = (time.time() - tic) / amount
    print('')
def train(args):
    np.random.seed(args.seed)
    if args.gpu:
        ctx = [mx.gpu(0)]
    else:
        ctx = [mx.cpu(0)]
    if args.dataset == "Sony":
        out_channels = 12
        scale = 2
    else:
        out_channels = 27
        scale = 3

    # load data
    train_transform = utils.Compose([
        utils.RandomCrop(args.patch_size, scale),
        utils.RandomFlipLeftRight(),
        utils.RandomFlipTopBottom(),
        utils.RandomTranspose(),
        utils.ToTensor(),
    ])
    train_dataset = data.MyDataset(args.dataset, "train", transform=train_transform)
    val_transform = utils.Compose([utils.ToTensor()])
    val_dataset = data.MyDataset(args.dataset, "val", transform=val_transform)
    train_loader = gluon.data.DataLoader(train_dataset, shuffle=True,
                                         batch_size=args.batch_size,
                                         last_batch='rollover')
    val_loader = gluon.data.DataLoader(val_dataset, batch_size=1, last_batch='discard')

    unet = net.UNet(out_channels, scale)
    unet.initialize(init=initializer.Xavier(), ctx=ctx)

    # optimizer and loss
    trainer = gluon.Trainer(unet.collect_params(), 'adam', {'learning_rate': args.lr})
    l1_loss = gluon.loss.L1Loss()

    print("Start training now..")
    for i in range(args.epochs):
        total_loss = 0
        count = 0
        profiler.set_state('run')
        for batch_id, (img, gt) in enumerate(train_loader):
            batch_size = img.shape[0]
            count += batch_size
            img_list = gluon.utils.split_and_load(img[0], ctx)
            gt_list = gluon.utils.split_and_load(gt[0], ctx)
            with autograd.record():
                preds = [unet(x) for x in img_list]
                losses = []
                for ii in range(len(preds)):
                    loss = l1_loss(gt_list[ii], preds[ii])
                    losses.append(loss)
            for loss in losses:
                loss.backward()
            total_loss += sum([l.sum().asscalar() for l in losses])
            avg_loss = total_loss / count
            trainer.step(batch_size)
            metric.update(gt_list, preds)
            F.waitall()
            profiler.set_state('stop')
            print(profiler.dumps())
            break  # profile only the first batch; the rest of the loop body is skipped

            gt_save = gt_list[0]
            output_save = preds[0]
            if (batch_id + 1) % 100 == 0:
                message = "Epoch {}: [{}/{}]: l1_loss: {:.4f}".format(
                    i + 1, count, len(train_dataset), avg_loss)
                print(message)
                temp = F.concat(gt_save, output_save, dim=3)
                temp = temp.asnumpy().reshape(temp.shape[2], temp.shape[3], 3)
                scipy.misc.toimage(temp * 255, high=255, low=0, cmin=0, cmax=255,
                                   mode='RGB').save(
                    args.save_model_dir + '%04d_%05d_00_train.jpg' % (i + 1, count))

        # evaluate
        batches = 0
        avg_psnr = 0.
        for img, gt in val_loader:
            batches += 1
            imgs = gluon.utils.split_and_load(img[0], ctx)
            label = gluon.utils.split_and_load(gt[0], ctx)
            outputs = []
            for x in imgs:
                outputs.append(unet(x))
            metric.update(label, outputs)
            avg_psnr += 10 * math.log10(1 / metric.get()[1])
            metric.reset()
        avg_psnr /= batches
        print('Epoch {}: validation avg psnr: {:.3f}'.format(i + 1, avg_psnr))

        # save model
        if (i + 1) % args.save_freq == 0:
            save_model_filename = "Epoch_" + str(i + 1) + ".params"
            save_model_path = os.path.join(args.save_model_dir, save_model_filename)
            unet.save_params(save_model_path)
            print("\nCheckpoint, trained model saved at", save_model_path)

    # save model
    save_model_filename = "Final_Epoch_" + str(i + 1) + ".params"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    unet.save_params(save_model_path)
    print("\nCheckpoint, trained model saved at", save_model_path)
def train(num_gpus=GPU_COUNT, batch_size=batch_size, lr=0.001, train_epoch=train_epoch):
    # train the model
    train_iter, test_iter = train_loader, val_loader
    ctx = [mx.gpu(i) for i in range(num_gpus)]
    # print('running on:', ctx)
    logging.info('running on:GPU')
    # loss = gluon.loss.SoftmaxCrossEntropyLoss()
    best_acc = 0.
    # loss = L2Softmax(classes=25, alpha=10)
    trainer = gluon.Trainer(net.collect_params(), 'Adam', {'learning_rate': lr})
    for epoch in range(train_epoch):
        train_loss = 0.
        start = time.time()
        for i, (X, y) in enumerate(train_iter):
            X = X / 255
            gpu_Xs = gutils.split_and_load(X, ctx, even_split=False)
            gpu_ys = gutils.split_and_load(y, ctx, even_split=False)
            with autograd.record():
                # ls = [loss(net(gpu_X), mx.nd.one_hot(gpu_y, classes))
                #       for gpu_X, gpu_y in zip(gpu_Xs, gpu_ys)]
                ls = [loss(net(gpu_X), gpu_y)
                      for gpu_X, gpu_y in zip(gpu_Xs, gpu_ys)]
            for l in ls:
                l.backward()
            trainer.step(batch_size)
            # use a different loop variable so the batch counter `i` is not shadowed
            ls_list = [nd.mean(l).asscalar() for l in ls]
            train_loss += sum(ls_list) / len(ls_list)
            if (i + 1) % 50 == 0:
                # print the training log
                logging.info(
                    "epoch[{epoch}] batch_num[{batch_num}] epochtrain_loss : {loss}"
                    .format(epoch=epoch + 1, batch_num=i + 1, loss=train_loss / (i + 1)))
        # test_acc = evaluate_accuracy(test_iter, net, ctx)
        # train_time = time.time() - start
        # logging.info('epoch %d, time %.1f sec, test acc %.7f' % (
        #     epoch + 1, train_time, test_acc))
        # net.save_parameters('weights/best_' + save_name + '.params')
        nd.waitall()
        train_time = time.time() - start
        test_acc = evaluate_accuracy(test_iter, net, ctx)
        logging.info('epoch[%d], time %.5f sec, test acc %.7f' %
                     (epoch + 1, train_time, test_acc))
        if test_acc > best_acc:
            best_acc = test_acc
            net.save_parameters('weights/best_' + save_name + '.params')
        if (epoch + 1) % 5 == 0:
            net.save_parameters('weights/gluon_' + save_name + str(epoch + 1) + '.params')
    net.save_parameters('weights/gluon_' + save_name + str(epoch + 1) + '.params')
    net.save_parameters('weights/last_' + save_name + '.params')
def main():
    parser = argparse.ArgumentParser(description='Script to test the trained network on a game.')
    parser.add_argument('-r', '--rom', required=False, type=str,
                        default=os.path.join('arena', 'games', 'roms', 'breakout.bin'),
                        help='Path of the ROM File.')
    parser.add_argument('-v', '--visualization', required=False, type=int, default=0,
                        help='Visualize the runs.')
    parser.add_argument('--lr', required=False, type=float, default=0.01,
                        help='Learning rate of the AdaGrad optimizer')
    parser.add_argument('--eps', required=False, type=float, default=0.01,
                        help='Eps of the AdaGrad optimizer')
    parser.add_argument('--clip-gradient', required=False, type=float, default=None,
                        help='Clip threshold of the AdaGrad optimizer')
    parser.add_argument('--double-q', required=False, type=bool, default=False,
                        help='Use Double DQN')
    parser.add_argument('--wd', required=False, type=float, default=0.0,
                        help='Weight of the L2 Regularizer')
    parser.add_argument('-c', '--ctx', required=False, type=str, default='gpu',
                        help='Running Context. E.g `-c gpu` or `-c gpu1` or `-c cpu`')
    parser.add_argument('-d', '--dir-path', required=False, type=str, default='',
                        help='Saving directory of model files.')
    parser.add_argument('--start-eps', required=False, type=float, default=1.0,
                        help='Eps of the epsilon-greedy policy at the beginning')
    parser.add_argument('--replay-start-size', required=False, type=int, default=50000,
                        help='The step that the training starts')
    parser.add_argument('--kvstore-update-period', required=False, type=int, default=1,
                        help='The period that the worker updates the parameters from the server')
    parser.add_argument('--kv-type', required=False, type=str, default=None,
                        help='type of kvstore, default will not use kvstore, could also be dist_async')
    args, unknown = parser.parse_known_args()
    if args.dir_path == '':
        rom_name = os.path.splitext(os.path.basename(args.rom))[0]
        args.dir_path = 'dqn-%s' % rom_name
    ctx = re.findall(r'([a-z]+)(\d*)', args.ctx)
    ctx = [(device, int(num)) if len(num) > 0 else (device, 0) for device, num in ctx]
    replay_start_size = args.replay_start_size
    max_start_nullops = 30
    replay_memory_size = 1000000
    history_length = 4
    rows = 84
    cols = 84
    q_ctx = mx.Context(*ctx[0])
    game = AtariGame(rom_path=args.rom, resize_mode='scale',
                     replay_start_size=replay_start_size,
                     resized_rows=rows, resized_cols=cols,
                     max_null_op=max_start_nullops,
                     replay_memory_size=replay_memory_size,
                     display_screen=args.visualization,
                     history_length=history_length)

    ## RUN NATURE
    freeze_interval = 10000
    epoch_num = 200
    steps_per_epoch = 250000
    update_interval = 4
    discount = 0.99
    eps_start = args.start_eps
    eps_min = 0.1
    eps_decay = (eps_start - 0.1) / 1000000
    eps_curr = eps_start
    freeze_interval /= update_interval
    minibatch_size = 32
    action_num = len(game.action_set)
    data_shapes = {'data': (minibatch_size, history_length) + (rows, cols),
                   'dqn_action': (minibatch_size,),
                   'dqn_reward': (minibatch_size,)}
    # optimizer = mx.optimizer.create(name='sgd', learning_rate=args.lr, wd=args.wd)
    optimizer = mx.optimizer.Nop()
    dqn_output_op = DQNOutputNpyOp()
    dqn_sym = dqn_sym_nature(action_num, dqn_output_op)
    qnet = Base(data_shapes=data_shapes, sym=dqn_sym, name='QNet',
                initializer=DQNInitializer(factor_type="in"), ctx=q_ctx)
    target_qnet = qnet.copy(name="TargetQNet", ctx=q_ctx)

    testShape = (1, 1686180 * 100)
    testParam = nd.ones(testShape, ctx=q_ctx)
    testGrad = nd.zeros(testShape, ctx=q_ctx)
    # Create kvstore
    if args.kv_type is not None:
        kvType = args.kv_type
        kvStore = kvstore.create(kvType)
        # Initialize kvstore
        for idx, v in enumerate(qnet.params.values()):
            kvStore.init(idx, v)
        # Set optimizer on kvstore
        kvStore.set_optimizer(optimizer)
        kvstore_update_period = args.kvstore_update_period
    else:
        updater = mx.optimizer.get_updater(optimizer)
    # if args.kv_type is not None:
    #     kvType = args.kv_type
    #     kvStore = kvstore.create(kvType)
    #     kvStore.init(0, testParam)
    #     testOptimizer = mx.optimizer.Nop()
    #     kvStore.set_optimizer(testOptimizer)
    #     kvstore_update_period = args.kvstore_update_period
    qnet.print_stat()
    target_qnet.print_stat()

    # Begin Playing Game
    training_steps = 0
    total_steps = 0
    while True:
        time_before_wait = time.time()
        # kvStore.push(0, testGrad, priority=0)
        # kvStore.pull(0, testParam, priority=0)
        # testParam.wait_to_read()
        for paramIndex in range(len(qnet.params)):  # range(6)
            # list() is needed in Python 3, where dict_keys is not subscriptable
            k = list(qnet.params.keys())[paramIndex]
            kvStore.push(paramIndex, qnet.params_grad[k], priority=-paramIndex)
            kvStore.pull(paramIndex, qnet.params[k], priority=-paramIndex)
        for v in qnet.params.values():
            v.wait_to_read()
        logging.info("wait time %f" % (time.time() - time_before_wait))
        for epoch in range(epoch_num):
            # Run Epoch
            steps_left = steps_per_epoch
            episode = 0
            epoch_reward = 0
            start = time.time()
            game.start()
            while steps_left > 0:
                # Running New Episode
                episode += 1
                episode_loss = 0.0
                episode_q_value = 0.0
                episode_update_step = 0
                episode_action_step = 0
                time_episode_start = time.time()
                game.begin_episode(steps_left)
                while not game.episode_terminate:
                    # 1. We need to choose a new action based on the current game status
                    if game.state_enabled and game.replay_memory.sample_enabled:
                        do_exploration = (npy_rng.rand() < eps_curr)
                        eps_curr = max(eps_curr - eps_decay, eps_min)
                        if do_exploration:
                            action = npy_rng.randint(action_num)
                        else:
                            # TODO Here we can in fact play multiple gaming instances simultaneously
                            # and make actions for each. We can simply stack the current_state() of
                            # gaming instances and give prediction for all of them.
                            # We need to wait after calling calc_score(.), which makes the program slow.
                            # TODO Profiling the speed of this part!
                            current_state = game.current_state()
                            state = nd.array(current_state.reshape((1,) + current_state.shape),
                                             ctx=q_ctx) / float(255.0)
                            qval_npy = qnet.forward(batch_size=1, data=state)[0].asnumpy()
                            action = numpy.argmax(qval_npy)
                            episode_q_value += qval_npy[0, action]
                            episode_action_step += 1
                    else:
                        action = npy_rng.randint(action_num)
                    # 2. Play the game for a single mega-step
                    #    (inside the game, the action may be repeated several times)
                    game.play(action)
                    total_steps += 1
                    # 3. Update our Q network if we can start sampling from the replay memory.
                    #    Also, we update every `update_interval`.
                    if total_steps % update_interval == 0 and game.replay_memory.sample_enabled:
                        # 3.1 Draw sample from the replay_memory
                        training_steps += 1
                        episode_update_step += 1
                        states, actions, rewards, next_states, terminate_flags \
                            = game.replay_memory.sample(batch_size=minibatch_size)
                        states = nd.array(states, ctx=q_ctx) / float(255.0)
                        next_states = nd.array(next_states, ctx=q_ctx) / float(255.0)
                        actions = nd.array(actions, ctx=q_ctx)
                        rewards = nd.array(rewards, ctx=q_ctx)
                        terminate_flags = nd.array(terminate_flags, ctx=q_ctx)
                        # 3.2 Use the target network to compute the scores and
                        #     get the corresponding target rewards
                        if not args.double_q:
                            target_qval = target_qnet.forward(batch_size=minibatch_size,
                                                              data=next_states)[0]
                            target_rewards = rewards + nd.choose_element_0index(
                                target_qval, nd.argmax_channel(target_qval)) \
                                * (1.0 - terminate_flags) * discount
                        else:
                            target_qval = target_qnet.forward(batch_size=minibatch_size,
                                                              data=next_states)[0]
                            qval = qnet.forward(batch_size=minibatch_size, data=next_states)[0]
                            target_rewards = rewards + nd.choose_element_0index(
                                target_qval, nd.argmax_channel(qval)) \
                                * (1.0 - terminate_flags) * discount
                        outputs = qnet.forward(batch_size=minibatch_size, is_train=True,
                                               data=states, dqn_action=actions,
                                               dqn_reward=target_rewards)
                        qnet.backward(batch_size=minibatch_size)
                        nd.waitall()
                        time_before_update = time.time()
                        if args.kv_type is not None:
                            if total_steps % kvstore_update_period == 0:
                                update_to_kvstore(kvStore, qnet.params, qnet.params_grad)
                        else:
                            qnet.update(updater=updater)
                        logging.info("update time %f" % (time.time() - time_before_update))
                        time_before_wait = time.time()
                        nd.waitall()
                        logging.info("wait time %f" % (time.time() - time_before_wait))
                        '''nd.waitall()
                        time_before_wait = time.time()
                        kvStore.push(0, testGrad, priority=0)
                        kvStore.pull(0, testParam, priority=0)
                        nd.waitall()
                        logging.info("wait time %f" % (time.time() - time_before_wait))'''
                        # 3.3 Calculate Loss
                        diff = nd.abs(nd.choose_element_0index(outputs[0], actions) - target_rewards)
                        quadratic_part = nd.clip(diff, -1, 1)
                        loss = (0.5 * nd.sum(nd.square(quadratic_part)) +
                                nd.sum(diff - quadratic_part)).asscalar()
                        episode_loss += loss
                        # 3.4 Update the target network every freeze_interval
                        #     (we can do annealing instead of a hard copy)
                        if training_steps % freeze_interval == 0:
                            qnet.copy_params_to(target_qnet)
                steps_left -= game.episode_step
                time_episode_end = time.time()
                # Update the statistics
                epoch_reward += game.episode_reward
                info_str = "Epoch:%d, Episode:%d, Steps Left:%d/%d, Reward:%f, fps:%f, Exploration:%f" \
                    % (epoch, episode, steps_left, steps_per_epoch, game.episode_reward,
                       game.episode_step / (time_episode_end - time_episode_start), eps_curr)
                if episode_update_step > 0:
                    info_str += ", Avg Loss:%f/%d" % (episode_loss / episode_update_step,
                                                      episode_update_step)
                if episode_action_step > 0:
                    info_str += ", Avg Q Value:%f/%d" % (episode_q_value / episode_action_step,
                                                         episode_action_step)
                logging.info(info_str)
            end = time.time()
            fps = steps_per_epoch / (end - start)
            qnet.save_params(dir_path=args.dir_path, epoch=epoch)
            logging.info("Epoch:%d, FPS:%f, Avg Reward: %f/%d"
                         % (epoch, fps, epoch_reward / float(episode), episode))
def test_RNN_class(typ="lstm", ret_typ="out"): num_hidden = (128, 128, 128) data_dim = 512 seq_len = 20 minibatch_size = 8 data_npy = numpy.random.standard_normal( (seq_len, minibatch_size, data_dim)) init_h_0 = numpy.random.standard_normal((minibatch_size, num_hidden[0])) init_h_1 = numpy.random.standard_normal((minibatch_size, num_hidden[1])) init_h_2 = numpy.random.standard_normal((minibatch_size, num_hidden[2])) if typ == "lstm": init_c_0 = numpy.random.standard_normal( (minibatch_size, num_hidden[0])) init_c_1 = numpy.random.standard_normal( (minibatch_size, num_hidden[1])) init_c_2 = numpy.random.standard_normal( (minibatch_size, num_hidden[2])) param_shapes = get_rnn_param_shapes(num_hidden=num_hidden, data_dim=data_dim, typ=typ) i2h_weight_npy = [ numpy.random.standard_normal(s) for s in param_shapes["i2h_weight"] ] i2h_bias_npy = [ numpy.random.standard_normal(s) / 100 for s in param_shapes["i2h_bias"] ] h2h_weight_npy = [ numpy.random.standard_normal(s) for s in param_shapes["h2h_weight"] ] h2h_bias_npy = [ numpy.random.standard_normal(s) / 100 for s in param_shapes["h2h_bias"] ] if ret_typ == "out": out_grad_npy = numpy.random.standard_normal( (seq_len, minibatch_size, num_hidden[2])) elif ret_typ == "state": mult = 2 if typ == "lstm" else 1 out_grad_npy = numpy.random.standard_normal( (minibatch_size, mult * (num_hidden[0] + num_hidden[1] + num_hidden[2]))) data = mx.sym.Variable("data") rnn = RNN(data_dim=data_dim, num_hidden=num_hidden, typ=typ, name=typ.upper()) rnn_cudnn = RNN(data_dim=data_dim, num_hidden=num_hidden, typ=typ, cudnn_opt=True, name=typ.upper() + "-cudnn") if ret_typ == "state": if typ == "lstm": rnn_out_h, rnn_out_c = rnn.step(data=data, seq_len=seq_len, ret_typ=ret_typ) rnn_cudnn_out_h, rnn_cudnn_out_c = rnn_cudnn.step(data=data, seq_len=seq_len, ret_typ=ret_typ) rnn_sym = mx.sym.Concat(*(rnn_out_h + rnn_out_c), num_args=len(num_hidden) * 2, dim=1) rnn_cudnn_sym = mx.sym.Concat(*(rnn_cudnn_out_h + rnn_cudnn_out_c), num_args=len(num_hidden) * 2, dim=1) else: rnn_out_h = rnn.step(data=data, seq_len=seq_len, ret_typ=ret_typ) rnn_cudnn_out_h = rnn_cudnn.step(data=data, seq_len=seq_len, ret_typ=ret_typ) rnn_sym = mx.sym.Concat(*(rnn_out_h), num_args=len(num_hidden), dim=1) rnn_cudnn_sym = mx.sym.Concat(*(rnn_cudnn_out_h), num_args=len(num_hidden), dim=1) else: rnn_out_h = rnn.step(data=data, seq_len=seq_len, ret_typ=ret_typ) rnn_cudnn_out_h = rnn_cudnn.step(data=data, seq_len=seq_len, ret_typ=ret_typ) rnn_sym = rnn_out_h[-1] rnn_cudnn_sym = rnn_cudnn_out_h[-1] if typ == "lstm": rnn_exe = rnn_sym.simple_bind( ctx=mx.gpu(), **{ 'data': (seq_len, minibatch_size, data_dim), rnn.name + '->layer0:init_h': (minibatch_size, num_hidden[0]), rnn.name + '->layer0:init_c': (minibatch_size, num_hidden[0]), rnn.name + '->layer1:init_h': (minibatch_size, num_hidden[1]), rnn.name + '->layer1:init_c': (minibatch_size, num_hidden[1]), rnn.name + '->layer2:init_h': (minibatch_size, num_hidden[2]), rnn.name + '->layer2:init_c': (minibatch_size, num_hidden[2]) }) rnn_cudnn_exe = rnn_cudnn_sym.simple_bind( ctx=mx.gpu(), **{ 'data': (seq_len, minibatch_size, data_dim), rnn_cudnn.name + '->layer0:init_h': (minibatch_size, num_hidden[0]), rnn_cudnn.name + '->layer0:init_c': (minibatch_size, num_hidden[0]), rnn_cudnn.name + '->layer1:init_h': (minibatch_size, num_hidden[1]), rnn_cudnn.name + '->layer1:init_c': (minibatch_size, num_hidden[1]), rnn_cudnn.name + '->layer2:init_h': (minibatch_size, num_hidden[2]), rnn_cudnn.name + '->layer2:init_c': (minibatch_size, num_hidden[2]) }) else: 
rnn_exe = rnn_sym.simple_bind(ctx=mx.gpu(), **{ 'data': (seq_len, minibatch_size, data_dim), rnn.name + '->layer0:init_h': (minibatch_size, num_hidden[0]), rnn.name + '->layer1:init_h': (minibatch_size, num_hidden[1]), rnn.name + '->layer2:init_h': (minibatch_size, num_hidden[2]) }) rnn_cudnn_exe = rnn_cudnn_sym.simple_bind( ctx=mx.gpu(), **{ 'data': (seq_len, minibatch_size, data_dim), rnn_cudnn.name + '->layer0:init_h': (minibatch_size, num_hidden[0]), rnn_cudnn.name + '->layer1:init_h': (minibatch_size, num_hidden[1]), rnn_cudnn.name + '->layer2:init_h': (minibatch_size, num_hidden[2]) }) for i in range(len(num_hidden)): rnn_exe.arg_dict[rnn.name + '->layer%d:i2h_weight' % i][:] = i2h_weight_npy[i] rnn_exe.arg_dict[rnn.name + '->layer%d:h2h_weight' % i][:] = h2h_weight_npy[i] rnn_exe.arg_dict[rnn.name + '->layer%d:i2h_bias' % i][:] = i2h_bias_npy[i] rnn_exe.arg_dict[rnn.name + '->layer%d:h2h_bias' % i][:] = h2h_bias_npy[i] rnn_cudnn_exe.arg_dict[rnn_cudnn.name + '->layer%d:i2h_weight' % i][:] = i2h_weight_npy[i] rnn_cudnn_exe.arg_dict[rnn_cudnn.name + '->layer%d:h2h_weight' % i][:] = h2h_weight_npy[i] rnn_cudnn_exe.arg_dict[rnn_cudnn.name + '->layer%d:i2h_bias' % i][:] = i2h_bias_npy[i] rnn_cudnn_exe.arg_dict[rnn_cudnn.name + '->layer%d:h2h_bias' % i][:] = h2h_bias_npy[i] N = 1 if typ == "lstm": start = time.time() for j in range(N): rnn_outputs = rnn_exe.forward(is_train=True, **{ "data": data_npy, rnn.name + '->layer0:init_h': init_h_0, rnn.name + '->layer0:init_c': init_c_0, rnn.name + '->layer1:init_h': init_h_1, rnn.name + '->layer1:init_c': init_c_1, rnn.name + '->layer2:init_h': init_h_2, rnn.name + '->layer2:init_c': init_c_2 }) rnn_exe.backward( out_grads=[mx.nd.array(out_grad_npy, ctx=mx.gpu())]) nd.waitall() end = time.time() print("MXNet %s Time: %g ms" % (typ.upper(), (end - start) / N * 1000)) start = time.time() for j in range(N): rnn_cudnn_outputs = rnn_cudnn_exe.forward( is_train=True, **{ "data": data_npy, rnn_cudnn.name + '->layer0:init_h': init_h_0, rnn_cudnn.name + '->layer0:init_c': init_c_0, rnn_cudnn.name + '->layer1:init_h': init_h_1, rnn_cudnn.name + '->layer1:init_c': init_c_1, rnn_cudnn.name + '->layer2:init_h': init_h_2, rnn_cudnn.name + '->layer2:init_c': init_c_2 }) rnn_cudnn_exe.backward( out_grads=[mx.nd.array(out_grad_npy, ctx=mx.gpu())]) nd.waitall() end = time.time() print("CuDNN %s Time: %g ms" % (typ.upper(), (end - start) / N * 1000)) else: start = time.time() for j in range(N): rnn_outputs = rnn_exe.forward(is_train=True, **{ "data": data_npy, rnn.name + '->layer0:init_h': init_h_0, rnn.name + '->layer1:init_h': init_h_1, rnn.name + '->layer2:init_h': init_h_2 }) rnn_exe.backward( out_grads=[mx.nd.array(out_grad_npy, ctx=mx.gpu())]) nd.waitall() end = time.time() print("MXNet %s Time: %g ms" % (typ.upper(), (end - start) / N * 1000)) start = time.time() for j in range(N): rnn_cudnn_outputs = rnn_cudnn_exe.forward( is_train=True, **{ "data": data_npy, rnn_cudnn.name + '->layer0:init_h': init_h_0, rnn_cudnn.name + '->layer1:init_h': init_h_1, rnn_cudnn.name + '->layer2:init_h': init_h_2 }) rnn_cudnn_exe.backward( out_grads=[mx.nd.array(out_grad_npy, ctx=mx.gpu())]) nd.waitall() end = time.time() print("CuDNN %s Time: %g ms" % (typ.upper(), (end - start) / N * 1000)) print( numpy.square(rnn_outputs[0].asnumpy() - rnn_cudnn_outputs[0].asnumpy()).mean()) for k, v in rnn_exe.grad_dict.items(): if k == 'data': #numpy.testing.assert_allclose(v.asnumpy(), rnn_cudnn_exe.grad_dict[k].asnumpy()) print(k, reldiff(v.asnumpy(), 
rnn_cudnn_exe.grad_dict[k].asnumpy())) else: postfix = k[k.find("->"):] #numpy.testing.assert_allclose(v.asnumpy(), rnn_cudnn_exe.grad_dict[rnn_cudnn.name + postfix].asnumpy()) print( k, reldiff( v.asnumpy(), rnn_cudnn_exe.grad_dict[rnn_cudnn.name + postfix].asnumpy()))
def train(cep, pool_size, epochs, train_data, val_data, ctx, netEn, netDe, netD, netD2, netDS, trainerEn, trainerDe, trainerD, trainerD2, trainerSD, lambda1, batch_size, expname, append=True, useAE=False): tp_file = open(expname + "_trainloss.txt", "w") tp_file.close() text_file = open(expname + "_validtest.txt", "w") text_file.close() #netGT, netDT, _, _ = set_test_network(opt.depth, ctx, opt.lr, opt.beta1,opt.ndf, opt.ngf, opt.append) GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss() L1_loss = gluon.loss.L2Loss() image_pool = imagePool.ImagePool(pool_size) metric = mx.metric.CustomMetric(facc) metric2 = mx.metric.CustomMetric(facc) metricStrong = mx.metric.CustomMetric(facc) metricMSE = mx.metric.MSE() loss_rec_G = [] loss_rec_D = [] loss_rec_R = [] acc_rec = [] acc2_rec = [] loss_rec_D2 = [] loss_rec_G2 = [] lr = 2.0 * 512 stamp = datetime.now().strftime('%Y_%m_%d-%H_%M') logging.basicConfig(level=logging.DEBUG) if cep == -1: cep = 0 else: netEn.load_params('checkpoints/' + opt.expname + '_' + str(cep) + '_En.params', ctx=ctx) netDe.load_params('checkpoints/' + opt.expname + '_' + str(cep) + '_De.params', ctx=ctx) netD.load_params('checkpoints/' + opt.expname + '_' + str(cep) + '_D.params', ctx=ctx) netD2.load_params('checkpoints/' + opt.expname + '_' + str(cep) + '_D2.params', ctx=ctx) netDS.load_params('checkpoints/' + opt.expname + '_' + str(cep) + '_SD.params', ctx=ctx) iter = 0 for epoch in range(cep + 1, epochs): tic = time.time() btic = time.time() train_data.reset() #print('learning rate : '+str(trainerD.learning_rate )) for batch in train_data: ############################ # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z))) ########################### if ctx == mx.cpu(): ct = mx.cpu() else: ct = mx.gpu() real_in = batch.data[0] #.as_in_context(ctx) real_out = batch.data[1] #.as_in_context(ctx) if iter == 0: latent_shape = (batch_size, 512, 1, 1) #code.shape out_l_shape = (batch_size, 1, 1, 1) #netD2((code)).shape out_i_shape = (batch_size, 1, 1, 1) #netD(netDe(code)).shape out_s_shape = (batch_size, 1, 1, 1) #netSD(netDe(code)).shape real_in = gluon.utils.split_and_load(real_in, ctx) real_out = gluon.utils.split_and_load(real_out, ctx) fake_latent = [netEn(r) for r in real_in] real_latent = nd.random.uniform(low=-1, high=1, shape=latent_shape) real_latent = gluon.utils.split_and_load(real_latent, ctx) fake_out = [netDe(f) for f in fake_latent] fake_concat = nd.concat(real_in, fake_out, dim=1) if append else fake_out eps2 = nd.random.uniform(low=-1, high=1, shape=latent_shape, ctx=ct) eps2 = gluon.utils.split_and_load(eps2, ctx) if epoch > 150: # (1/float(batch_size))*512*150:# and epoch%10==0: print('Mining..') mu = nd.random.uniform(low=-1, high=1, shape=latent_shape, ctx=ct) #isigma = nd.ones((batch_size,64,1,1),ctx=ctx)*0.000001 mu.attach_grad() #sigma.attach_grad() images = netDe(mu) fake_img1T = nd.concat(images[0], images[1], images[2], dim=1) fake_img2T = nd.concat(images[3], images[4], images[5], dim=1) fake_img3T = nd.concat(images[6], images[7], images[8], dim=1) fake_img = nd.concat(fake_img1T, fake_img2T, fake_img3T, dim=2) visual.visualize(fake_img) plt.savefig('outputs/' + expname + '_fakespre_' + str(epoch) + '.png') eps2 = gluon.utils.split_and_load(mu, ctx) for e in eps2: e.attach_grad() for ep2 in range(1): with autograd.record(): #eps = nd.random_normal(loc=0, scale=1, shape=fake_latent.shape, ctx=ctx) # #eps2 = gluon.utils.split_and_load(nd.tanh(mu),ctx) #+nd.multiply(eps,sigma))#nd.random.uniform( low=-1, high=1, 
                        shape=fake_latent.shape, ctx=ctx)
                    rec_output = [netDS(netDe(e)) for e in eps2]
                    fake_label = gluon.utils.split_and_load(nd.zeros(out_s_shape), ctx)
                    errGS = [GAN_loss(r, f) for r, f in zip(rec_output, fake_label)]
                    for e in errGS:
                        e.backward()
                    # Gradient step taken directly on the latent codes, scaled by the
                    # per-device batch size and re-squashed into [-1, 1].
                    for idx, _ in enumerate(eps2):
                        eps2[idx] = nd.tanh(eps2[idx] - lr / eps2[idx].shape[0] * eps2[idx].grad)
                    # Decode the refined latents and save a 3x3 grid of samples.
                    images = netDe(eps2[0])
                    fake_img1T = nd.concat(images[0], images[1], images[2], dim=1)
                    fake_img2T = nd.concat(images[3], images[4], images[5], dim=1)
                    fake_img3T = nd.concat(images[6], images[7], images[8], dim=1)
                    fake_img = nd.concat(fake_img1T, fake_img2T, fake_img3T, dim=2)
                    visual.visualize(fake_img)
                    plt.savefig('outputs/' + expname + str(ep2) + '_fakespost_' + str(epoch) + '.png')

            with autograd.record():
                # Train the discriminators with fake images.
                output = [netD(f) for f in fake_concat]
                output2 = [netD2(f) for f in fake_latent]
                fake_label = gluon.utils.split_and_load(nd.zeros(out_i_shape), ctx)
                fake_latent_label = gluon.utils.split_and_load(nd.zeros(out_l_shape), ctx)
                eps = gluon.utils.split_and_load(
                    nd.random.uniform(low=-1, high=1, shape=latent_shape), ctx)
                rec_output = [netD(netDe(e)) for e in eps]
                errD_fake = [GAN_loss(r, f) for r, f in zip(rec_output, fake_label)]
                # errD_fake2 is computed but not used in the combined loss below.
                errD_fake2 = [GAN_loss(o, f) for o, f in zip(output, fake_label)]
                errD2_fake = [GAN_loss(o, f) for o, f in zip(output2, fake_latent_label)]
                for f, o in zip(fake_label, rec_output):
                    metric.update([f], [o])
                for f, o in zip(fake_latent_label, output2):
                    metric2.update([f], [o])
                # Train the discriminators with real images.
                real_concat = [nd.concat(i, o, dim=1)
                               for i, o in zip(real_in, real_out)] if append else real_out
                output = [netD(r) for r in real_concat]
                output2 = [netD2(r) for r in real_latent]
                real_label = gluon.utils.split_and_load(nd.ones(out_i_shape), ctx)
                real_latent_label = gluon.utils.split_and_load(nd.ones(out_l_shape), ctx)
                errD_real = [GAN_loss(o, r) for o, r in zip(output, real_label)]
                errD2_real = [GAN_loss(o, r) for o, r in zip(output2, real_latent_label)]
                for d_r, d_f, d2_r, d2_f in zip(errD_real, errD_fake, errD2_real, errD2_fake):
                    err = (d_r + d_f) * 0.5 + (d2_r + d2_f) * 0.5
                    err.backward()
                for f, o in zip(real_label, output):
                    metric.update([f], [o])
                for f, o in zip(real_latent_label, output2):
                    metric2.update([f], [o])
            trainerD.step(batch.data[0].shape[0])
            trainerD2.step(batch.data[0].shape[0])
            nd.waitall()

            # Train the strong discriminator netDS: decoded noise counts as fake,
            # generator reconstructions count as real.
            with autograd.record():
                strong_output = [netDS(netDe(e)) for e in eps]
                strong_real = [netDS(f) for f in fake_concat]
                errs1 = [GAN_loss(r, f) for r, f in zip(strong_output, fake_label)]
                errs2 = [GAN_loss(r, f) for r, f in zip(strong_real, real_label)]
                for f, s in zip(fake_label, strong_output):
                    metricStrong.update([f], [s])
                for f, s in zip(real_label, strong_real):
                    metricStrong.update([f], [s])
                for e1, e2 in zip(errs1, errs2):
                    strongerr = 0.5 * (e1 + e2)
                    strongerr.backward()
            trainerSD.step(batch.data[0].shape[0])
            nd.waitall()

            ############################
            # (2) Update G network: maximize log(D(x, G(x, z))) - lambda1 * L1(y, G(x, z))
            ############################
            with autograd.record():
                rec_output = [netD(netDe(e)) for e in eps2]
                fake_latent = [netEn(r) for r in real_in]
                output2 = [netD2(f) for f in fake_latent]
                fake_out = [netDe(f) for f in fake_latent]
                fake_concat = [nd.concat(i, o, dim=1)
                               for i, o in zip(real_in, fake_out)] if append else fake_out
                output = [netD(f) for f in fake_concat]
                real_label = gluon.utils.split_and_load(nd.ones(out_i_shape), ctx)
                real_latent_label = gluon.utils.split_and_load(nd.ones(out_l_shape), ctx)
                errG2 = [GAN_loss(r, f) for r, f in zip(rec_output, real_label)]
                errR = [L1_loss(r, f) * lambda1 for r, f in zip(real_out, fake_out)]
                errG = [10 * GAN_loss(r, f) for r, f in zip(output2, real_latent_label)]
                for e1, e2, e3 in zip(errG, errG2, errR):
                    e = e1 + e2 + e3
                    e.backward()
            trainerDe.step(batch.data[0].shape[0])
            trainerEn.step(batch.data[0].shape[0])
            nd.waitall()

            # Per-batch bookkeeping (first-device losses only) for the returned curves.
            errD = (errD_real[0] + errD_fake[0]) * 0.5
            errD2 = (errD2_real[0] + errD2_fake[0]) * 0.5
            loss_rec_G2.append(nd.mean(errG2[0]).asscalar())
            loss_rec_G.append(nd.mean(errG[0]).asscalar()
                              - nd.mean(errG2[0]).asscalar()
                              - nd.mean(errR[0]).asscalar())
            loss_rec_D.append(nd.mean(errD).asscalar())
            loss_rec_R.append(nd.mean(errR[0]).asscalar())
            loss_rec_D2.append(nd.mean(errD2).asscalar())
            _, acc2 = metric2.get()
            name, acc = metric.get()
            acc_rec.append(acc)
            acc2_rec.append(acc2)

            # Print log information every ten batches.
            if iter % 10 == 0:
                _, acc2 = metric2.get()
                name, acc = metric.get()
                _, accStrong = metricStrong.get()
                logging.info('speed: {} samples/s'.format(
                    batch_size / (time.time() - btic)))
            iter = iter + 1
            btic = time.time()

        name, acc = metric.get()
        _, acc2 = metric2.get()
        metric.reset()
        metric2.reset()
        train_data.reset()
        metricStrong.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))

        if epoch % 2 == 0:
            # Checkpoint all five networks.
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_D.params"
            netD.save_parameters(filename)
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_D2.params"
            netD2.save_parameters(filename)
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_En.params"
            netEn.save_parameters(filename)
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_De.params"
            netDe.save_parameters(filename)
            filename = "checkpoints/" + expname + "_" + str(epoch) + "_SD.params"
            netDS.save_parameters(filename)
            fake_img1 = nd.concat(real_in[0], real_out[0], fake_out[0], dim=1)
            fake_img2 = nd.concat(real_in[1], real_out[1], fake_out[1], dim=1)
            fake_img3 = nd.concat(real_in[2], real_out[2], fake_out[2], dim=1)
            fake_img4 = nd.concat(real_in[3], real_out[3], fake_out[3], dim=1)
            # Measure reconstruction quality on the validation set.
            val_data.reset()
            text_file = open(expname + "_validtest.txt", "a")
            for vbatch in val_data:
                real_in = gluon.utils.split_and_load(vbatch.data[0], ctx)
                real_out = gluon.utils.split_and_load(vbatch.data[1], ctx)
                fake_latent = [netEn(r) for r in real_in]
                fake_out = [netDe(f) for f in fake_latent]
                for f, r in zip(fake_out, real_out):
                    metricMSE.update([f], [r])
            _, acc2 = metricMSE.get()
            toterrR = 0
            for e in errR:
                toterrR += nd.mean(e).asscalar()
            text_file.write("%s %s %s\n" % (str(epoch), toterrR, str(acc2)))
            text_file.close()
            metricMSE.reset()
    return [loss_rec_D, loss_rec_G, loss_rec_R, acc_rec,
            loss_rec_D2, loss_rec_G2, acc2_rec]
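
# A minimal, self-contained sketch of the latent-refinement pattern used above
# (eps2[idx] = nd.tanh(eps2[idx] - lr / batch * eps2[idx].grad)): the latent batch
# itself is the optimization variable. attach_grad() allocates a gradient buffer so
# backward() can populate eps.grad. The quadratic objective below is a hypothetical
# stand-in for GAN_loss(netDS(netDe(eps)), label); shapes and lr are illustrative.
def refine_latents_sketch():
    import mxnet as mx
    from mxnet import nd, autograd
    ctx = mx.cpu()
    eps = nd.random.uniform(low=-1, high=1, shape=(4, 16), ctx=ctx)
    lr = 0.01
    for step in range(5):
        eps.attach_grad()                # fresh gradient buffer for this step
        with autograd.record():
            loss = nd.sum(eps * eps)     # stand-in for the discriminator loss
        loss.backward()
        # Gradient step scaled by batch size, then re-squash into [-1, 1],
        # the range the decoder expects.
        eps = nd.tanh(eps - lr / eps.shape[0] * eps.grad)
        nd.waitall()
    return eps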
def train(opt):
    sw = SummaryWriter(logdir='./logs', flush_secs=5)
    decay_every = int(opt.n_epoch / 2)
    if opt.experiment is None:
        opt.experiment = 'samples'
    os.makedirs(opt.experiment, exist_ok=True)
    if opt.gpu_ids == '-1':
        context = [mx.cpu()]
    else:
        context = [mx.gpu(int(i)) for i in opt.gpu_ids.split(',') if i.strip()]
    print("context: {}".format(context))
    features = load_vgg_model_features(ctx_list=context, last_layer=28)

    ##### Prepare data for training or validation #####
    dataset = DataSet(
        opt.dataroot,
        RandomCrop(opt.fineSize),
        transforms.Resize(int(opt.fineSize / 4), interpolation=3),
        transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
    dataloader = DataLoader(dataset,
                            batch_size=opt.batchSize,
                            shuffle=True,
                            num_workers=int(opt.workers),
                            last_batch='rollover')

    ##### Build Network #####
    netG = SRGenerator()
    netG.initialize(ctx=context[0])
    netD = SRDiscriminator()
    netD.initialize(ctx=context[0])
    # Force actual (non-deferred) initialization with one forward pass.
    dummy_in = nd.random.uniform(0, 1,
                                 (1, 3, int(opt.fineSize / 4), int(opt.fineSize / 4)),
                                 ctx=context[0])
    netD(netG(dummy_in))
    # Apply our own weight initialization, then spread parameters across devices.
    weights_init(netG.collect_params())
    netG.collect_params().reset_ctx(context)
    weights_init(netD.collect_params())
    netD.collect_params().reset_ctx(context)
    optimizer_G = gluon.Trainer(params=netG.collect_params(),
                                optimizer='adam',
                                optimizer_params={'learning_rate': opt.lr_init,
                                                  'beta1': opt.beta1},
                                kvstore='local')
    optimizer_D = gluon.Trainer(params=netD.collect_params(),
                                optimizer='adam',
                                optimizer_params={'learning_rate': opt.lr_init,
                                                  'beta1': opt.beta1},
                                kvstore='local')

    ##### Stage 1/2 of Training Process #####
    # Pre-train generator G to avoid undesired local optima when training SRGAN.
    print("Start pre-train Generator ...")
    param_file = os.path.join(opt.experiment, 'netG_init_epoch.param')
    if os.path.exists(param_file):
        print("Found pre-trained parameter file {}; skipping the pre-train stage."
              .format(param_file))
        netG.load_parameters(param_file, ctx=context)
    else:
        print("No existing parameter file; pre-training from scratch.")
        for epoch in range(opt.n_epoch_init):
            start = time.time()
            batch = 0
            for hr_img_iter, lr_img_iter in dataloader:
                hr_imgs = gluon.utils.split_and_load(hr_img_iter, ctx_list=context)
                lr_imgs = gluon.utils.split_and_load(lr_img_iter, ctx_list=context)
                with autograd.record():
                    ls = [mse_loss(hr_img, netG(lr_img))
                          for hr_img, lr_img in zip(hr_imgs, lr_imgs)]
                for l in ls:
                    l.backward()
                optimizer_G.step(opt.batchSize)
                print("Epoch %d: Sample %d: mse: %.8f"
                      % (epoch, batch, ls[-1].mean().asscalar()))
                batch += opt.batchSize
            nd.waitall()
            train_time = time.time() - start
            print("Epoch %d: mse: %.8f training time: %.1f sec"
                  % (epoch, ls[-1].mean().asscalar(), train_time))
            if epoch % 20 == 0:
                netG.save_parameters('{0}/netG_init_epoch_{1}.param'.format(
                    opt.experiment, epoch))
            if epoch == opt.n_epoch_init - 1:
                netG.save_parameters('{0}/netG_init_epoch.param'.format(
                    opt.experiment))
    print("Pre-train Generator finished ...")

    ##### Stage 2/2 of Training Process #####
    # Jointly optimize G and D, i.e. train SRGAN proper.
    print("Start to train SRGAN ...")
    # Per-channel ImageNet mean/std masks for the VGG (perceptual) loss.
    mean_mask = nd.zeros((opt.batchSize, 3, opt.fineSize, opt.fineSize), ctx=context[0])
    mean_mask[:, 0, :, :] = 0.485
    mean_mask[:, 1, :, :] = 0.456
    mean_mask[:, 2, :, :] = 0.406
    std_mask = nd.zeros((opt.batchSize, 3, opt.fineSize, opt.fineSize), ctx=context[0])
    std_mask[:, 0, :, :] = 0.229
    std_mask[:, 1, :, :] = 0.224
    std_mask[:, 2, :, :] = 0.225
    real_label = nd.ones((opt.batchSize,), ctx=context[0])
    fake_label = nd.zeros((opt.batchSize,), ctx=context[0])
    mean_masks = mx.gluon.utils.split_and_load(mean_mask, ctx_list=context)
    std_masks = mx.gluon.utils.split_and_load(std_mask, ctx_list=context)
    real_labels = mx.gluon.utils.split_and_load(real_label, ctx_list=context)
    fake_labels = mx.gluon.utils.split_and_load(fake_label, ctx_list=context)
    loss_d = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    losses_log = loss_dict()

    for epoch in range(0, opt.n_epoch):
        start = time.time()
        batch = 0
        train_errD = 0
        train_errG = 0
        for hr_img_iter, lr_img_iter in dataloader:
            losses_log.reset()
            hr_imgs = gluon.utils.split_and_load(hr_img_iter, ctx_list=context)
            lr_imgs = gluon.utils.split_and_load(lr_img_iter, ctx_list=context)
            hr_fake_imgs = []
            batch_errD = []
            batch_errG = []

            # Step 1. Optimize D.
            print("Optimize D in a Batch...")
            with autograd.record():
                for hr_img, lr_img, mean_mask, std_mask, real_label, fake_label in zip(
                        hr_imgs, lr_imgs, mean_masks, std_masks, real_labels, fake_labels):
                    # errD: real images should score 1, generated images 0.
                    output = netD(hr_img).reshape((-1, 1))
                    errD_real = loss_d(output, real_label)
                    hr_img_fake = netG(lr_img)
                    hr_fake_imgs.append(hr_img_fake)
                    output = netD(hr_img_fake.detach()).reshape((-1, 1))
                    errD_fake = loss_d(output, fake_label)
                    errD = errD_real + errD_fake
                    batch_errD.append(errD)
                    losses_log.add(lr_img=lr_img, hr_img=hr_img, hr_img_fake=hr_img_fake)
            # Run backward on the per-device D losses and update parameters.
            autograd.backward(batch_errD)
            optimizer_D.step(opt.batchSize)

            # Step 2. Optimize G.
            print("Optimize G in a Batch...")
            with autograd.record():
                for hr_img, lr_img, hr_img_fake, mean_mask, std_mask, real_label, fake_label in zip(
                        hr_imgs, lr_imgs, hr_fake_imgs, mean_masks, std_masks,
                        real_labels, fake_labels):
                    # errG = pixel MSE + weighted VGG perceptual loss + adversarial loss.
                    errM = mse_loss(hr_img_fake, hr_img)
                    input_fake = ((hr_img_fake + 1) / 2 - mean_mask) / std_mask
                    fake_emb = vgg_feature(input_fake, features)
                    input_real = ((hr_img + 1) / 2 - mean_mask) / std_mask
                    real_emb = vgg_feature(input_real, features)
                    errV = 6e-3 * mse_loss(fake_emb, real_emb)
                    output = netD(hr_img_fake).reshape((-1, 1))
                    errA = 1e-3 * loss_d(output, real_label)
                    errG = errM + errV + errA
                    batch_errG.append(errG)
            # Run backward on the per-device G losses and update parameters.
            autograd.backward(batch_errG)
            optimizer_G.step(opt.batchSize)

            # Sum losses over all devices.
            train_errD += sum([errD.sum().asscalar() for errD in batch_errD])
            train_errG += sum([errG.sum().asscalar() for errG in batch_errG])
            print("Epoch:%d, Sample:%d ----- D-Loss = %.3f, G-Loss = %.3f (Time %.1f sec)"
                  % (epoch, batch * opt.batchSize, train_errD, train_errG,
                     time.time() - start))
            batch += 1
            plot_loss(sw, losses_log, epoch * len(dataloader) + batch, epoch, batch)

        if epoch != 0 and (epoch % decay_every == 0):
            optimizer_G.set_learning_rate(optimizer_G.learning_rate * opt.lr_decay)
            optimizer_D.set_learning_rate(optimizer_D.learning_rate * opt.lr_decay)
        if (epoch != 0) and (epoch % 10 == 0):
            plot_img(sw, losses_log)
            netG.save_parameters('{0}/netG_epoch_{1}.param'.format(opt.experiment, epoch))
            netD.save_parameters('{0}/netD_epoch_{1}.param'.format(opt.experiment, epoch))
    print("Train SRGAN finished ...")
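
# Hypothetical driver for train(opt), included for illustration only: the flag
# names mirror the opt.* fields the function actually reads (dataroot, experiment,
# gpu_ids, fineSize, batchSize, workers, lr_init, beta1, n_epoch_init, n_epoch,
# lr_decay); the defaults are assumptions, not values taken from this codebase.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='SRGAN training (sketch)')
    parser.add_argument('--dataroot', required=True, help='path to HR training images')
    parser.add_argument('--experiment', default=None, help='output dir for checkpoints/samples')
    parser.add_argument('--gpu_ids', default='-1',
                        help="comma-separated GPU ids, e.g. '0,1'; '-1' for CPU")
    parser.add_argument('--fineSize', type=int, default=96,
                        help='HR crop size; the LR input is fineSize/4')
    parser.add_argument('--batchSize', type=int, default=16)
    parser.add_argument('--workers', type=int, default=4, help='DataLoader worker processes')
    parser.add_argument('--lr_init', type=float, default=1e-4, help='Adam learning rate')
    parser.add_argument('--beta1', type=float, default=0.9, help='Adam beta1')
    parser.add_argument('--n_epoch_init', type=int, default=100,
                        help='generator pre-training epochs')
    parser.add_argument('--n_epoch', type=int, default=2000,
                        help='joint SRGAN training epochs')
    parser.add_argument('--lr_decay', type=float, default=0.1,
                        help='LR multiplier applied every n_epoch/2 epochs')
    train(parser.parse_args())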