def test_onehot():
    runtime = get_runtime()
    param = ng.parameter([3], dtype=np.int32)
    model = ng.one_hot(param, 3, 1, 0, 0)
    computation = runtime.computation(model, param)

    expected = np.eye(3)[np.array([1, 0, 2])]
    input_data = np.array([1, 0, 2], dtype=np.int32)
    result = computation(input_data)
    assert np.allclose(result, expected)
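# Editor's note: a minimal NumPy sketch of what the test above expects,
# assuming the positional arguments of ng.one_hot are (indices, depth,
# on_value, off_value, axis); the helper below is illustrative only.
import numpy as np

def one_hot_reference(indices, depth, on_value=1, off_value=0, axis=0):
    eye = np.where(np.eye(depth, dtype=bool), on_value, off_value)
    # one-hot along the last axis, then move it to the requested position
    return np.moveaxis(eye[np.asarray(indices)], -1, axis)

# np.eye(3)[[1, 0, 2]] happens to be symmetric, so it matches for axis 0 too
assert np.allclose(one_hot_reference([1, 0, 2], 3, axis=0),
                   np.eye(3)[np.array([1, 0, 2])])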
def SparseSoftmaxCrossEntropyWithLogits(self, tf_node, inputs):
    """
    Computes softmax cross entropy. The inputs `logits` are unscaled log
    probabilities, and each entry of `labels` must be a valid class index.

    Reference: https://goo.gl/z5T2my

    Arguments:
        tf_node: NodeDef object, the tensorflow node to convert.
        inputs: List of ngraph Ops as inputs to this node.

    Returns:
        An ngraph Op corresponding to the tensorflow node.

    Inputs to tf_node:
        logits, labels, name
    """
    # logits: (N1, Y1), labels: (N2,)
    logits, labels = inputs

    # check input dimension
    try:
        assert len(logits.axes) == 2
        assert len(labels.axes) == 1
        assert logits.axes[0].length == labels.axes[0].length
    except AssertionError:
        raise NotImplementedError("logits' shape must be (N, Y), "
                                  "labels' shape must be (N,), "
                                  "other shapes not supported yet.")

    # get axis
    axis_y = logits.axes[1]

    # labels_one_hot: (Y2, N2)
    labels_one_hot = ng.one_hot(labels, axis=axis_y)

    # predicts: (N1, Y1)
    predicts = ng.softmax(logits, normalization_axes=axis_y)

    # dim-shuffle / cast to (Y1, N1)
    predicts_axes = ng.make_axes(
        [axis for axis in reversed(predicts.axes)])
    predicts = ng.axes_with_order(predicts, axes=predicts_axes)
    labels_one_hot = ng.cast_axes(labels_one_hot, predicts_axes)

    # cross_entropy: (N1,)
    cross_entropy = ng.cross_entropy_multi(predicts, labels_one_hot,
                                           out_axes=(logits.axes[0],))
    return cross_entropy
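# Editor's note: for orientation, a self-contained NumPy reference for the
# quantity this converter builds (illustrative, not part of the importer):
# softmax over the class axis, then the negative log-probability of the
# true class for each example.
import numpy as np

def sparse_softmax_xent_reference(logits, labels):
    # subtract the per-row max for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]

logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 2.5, 0.3]])
labels = np.array([0, 1])
print(sparse_softmax_xent_reference(logits, labels))  # per-example loss, shape (N,)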
def LabelCrossEntropy(self, c2_op, inputs):
    """
    Computes the cross entropy between the input and the label set.

    Arguments:
        c2_op: OperatorDef object, the caffe2 node to convert.
        inputs: List of ngraph Ops as inputs to this node.

    Returns:
        An ngraph Op corresponding to the caffe2 node.
    """
    y, labels = inputs
    labels_one_hot = ng.one_hot(labels, axis=y.axes[1])
    labels_one_hot = ng.cast_axes(labels_one_hot,
                                  [labels_one_hot.axes[0], y.axes[0]])
    return ng.cross_entropy_multi(y, labels_one_hot, out_axes=y.axes[0])
def run_resnet_benchmark(dataset, num_iterations, n_skip, batch_size, device_id,
                         transformer_type, device, bprop=True, batch_norm=False,
                         visualize=False, stage_depth=1):
    inputs, data, train_set = get_fake_data(dataset, batch_size, num_iterations)

    # Running forward propagation
    model_out = get_mini_resnet(inputs, dataset, device, device_id,
                                batch_norm=batch_norm, stage_depth=stage_depth)

    # Running back propagation
    if bprop:
        with ng.metadata(device=device, device_id=device_id, parallel=ax.N):
            optimizer = GradientDescentMomentum(0.01, 0.9)
            train_loss = ng.cross_entropy_multi(
                model_out,
                ng.one_hot(inputs['label'], axis=ax.Y))

            batch_cost = ng.sequential(
                [optimizer(train_loss), ng.mean(train_loss, out_axes=())])
            batch_cost_computation_op = ng.computation(batch_cost, "all")

        benchmark = Benchmark(batch_cost_computation_op, train_set, inputs,
                              transformer_type, device)
        Benchmark.print_benchmark_results(
            benchmark.time(num_iterations, n_skip, dataset + '_msra_bprop',
                           visualize, 'device_id'))
    else:
        fprop_computation_op = ng.computation(model_out, 'all')
        benchmark = Benchmark(fprop_computation_op, train_set, inputs,
                              transformer_type, device)
        Benchmark.print_benchmark_results(
            benchmark.time(num_iterations, n_skip, dataset + '_msra_fprop',
                           visualize))
def one_hot_comparison(hot_axes, axes, C):
    """
    Generate a random integer tensor, one-hot encode it with numpy, and
    compare the result against ng.one_hot.

    Arguments:
        hot_axes: axes of the expected one-hot output (C prepended to axes)
        axes: axes of the integer input tensor
        C: the one-hot (class) axis
    """
    u = rng.random_integers(0, C.length - 1, axes, dtype=np.int8)
    u_p = ng.placeholder(axes, dtype=u.dtype)
    v = np.zeros(hot_axes.lengths, dtype=np.float32)
    udxiter = np.nditer(u, flags=['multi_index'])
    for uiter in udxiter:
        vindex = [int(uiter)]
        vindex.extend(udxiter.multi_index)
        v[tuple(vindex)] = 1

    v_t = executor(ng.one_hot(u_p, axis=C), u_p)(u)
    np.testing.assert_allclose(v_t, v)
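# Editor's note: a hypothetical invocation of the helper above; the axis
# names and lengths are illustrative, using the ng.make_axis API seen
# elsewhere in this section.
C = ng.make_axis(length=4, name='C')
N = ng.make_axis(length=8, name='N')
one_hot_comparison(ng.make_axes([C, N]), ng.make_axes([N]), C)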
def run_cifar_benchmark(n_iter=10, n_skip=5, batch_size=4, transformer_type='cpu'):
    inputs, data, train_set = get_fake_cifar(batch_size, n_iter)
    model = get_mini_resnet(inputs)
    optimizer = GradientDescentMomentum(0.01, 0.9)

    train_loss = ng.cross_entropy_multi(model(inputs['image']),
                                        ng.one_hot(inputs['label'], axis=ax.Y))

    batch_cost = ng.sequential(
        [optimizer(train_loss), ng.mean(train_loss, out_axes=())])
    batch_cost_computation_op = ng.computation(batch_cost, "all")

    feed_dict = fill_feed_dict(train_set, inputs)
    benchmarks = dict()
    benchmarks['cifar_msra_fprop'] = run_benchmark(batch_cost_computation_op,
                                                   transformer_type, feed_dict,
                                                   n_skip, n_iter)
    print_benchmark_results(benchmarks)
def sparse_softmax_cross_entropy_with_logits(labels=None, logits=None, name=None):
    """
    Computes softmax cross entropy. The inputs `logits` are unscaled log
    probabilities, and each entry of `labels` must be a valid class index.

    Args:
        labels: of axis (N,) for (POS_0,)
        logits: of axis (N, Y) for (POS_1, POS_0)
        name: name of the ngraph op
    """
    # Check input dimension
    #         (    N,     Y),         (    N)
    # logits: (pos_1, pos_0), labels: (pos_0)
    try:
        assert len(logits.axes) == 2
        assert len(labels.axes) == 1
        assert logits.axes[0].length == labels.axes[0].length
    except AssertionError:
        raise NotImplementedError("logits' shape must be (N, Y), "
                                  "labels' shape must be (N,), "
                                  "other shapes not supported yet.")

    # get axis
    axis_n, axis_y = logits.axes

    # convert labels to one-hot labels
    labels = ng.cast_axes(labels, ng.make_axes(axis_n))
    labels = ng.one_hot(labels, axis=axis_y)
    labels = ng.axes_with_order(labels, axes=logits.axes)

    # predicts: (N, Y)
    predicts = ng.softmax(logits, normalization_axes=axis_y)

    # cross_entropy: (N)
    res = ng.cross_entropy_multi(predicts, labels, out_axes=(axis_n,))

    return cast_to_pos_axes(res).named(name)
def run_resnet_benchmark(dataset, n_iter, n_skip, batch_size, device_id,
                         transformer_type, device, bprop=False, visualize=False):
    inputs, data, train_set = get_fake_data(dataset, batch_size, n_iter)
    model_out = get_mini_resnet(inputs, dataset, device_id)

    # Running forward propagation
    fprop_computation_op = ng.computation(model_out, 'all')
    benchmark_fprop = Benchmark(fprop_computation_op, train_set, inputs,
                                transformer_type, device)
    Benchmark.print_benchmark_results(
        benchmark_fprop.time(n_iter, n_skip, dataset + '_msra_fprop', visualize))

    # Running back propagation
    if bprop:
        optimizer = GradientDescentMomentum(0.01, 0.9)
        train_loss = ng.cross_entropy_multi(
            model_out, ng.one_hot(inputs['label'], axis=ax.Y))

        batch_cost = ng.sequential([optimizer(train_loss),
                                    ng.mean(train_loss, out_axes=())])
        batch_cost_computation_op = ng.computation(batch_cost, "all")

        benchmark = Benchmark(batch_cost_computation_op, train_set, inputs,
                              transformer_type, device)
        Benchmark.print_benchmark_results(
            benchmark.time(n_iter, n_skip, dataset + '_msra_bprop', visualize))
def cifar_mean_subtract(x):
    bgr_mean = ng.persistent_tensor(
        axes=x.axes.find_by_name('C'),
        initial_value=np.array([104., 119., 127.]))
    return (x - bgr_mean) / 255.


seq1 = Sequential([Preprocess(functor=cifar_mean_subtract),
                   Affine(nout=200, weight_init=UniformInit(-0.1, 0.1),
                          activation=Rectlin()),
                   Affine(axes=ax.Y, weight_init=UniformInit(-0.1, 0.1),
                          activation=Softmax())])

optimizer = GradientDescentMomentum(0.1, 0.9)
train_prob = seq1(inputs['image'])
train_loss = ng.cross_entropy_multi(train_prob,
                                    ng.one_hot(inputs['label'], axis=ax.Y))
batch_cost = ng.sequential([optimizer(train_loss),
                            ng.mean(train_loss, out_axes=())])
train_outputs = dict(batch_cost=batch_cost)

with Layer.inference_mode_on():
    inference_prob = seq1(inputs['image'])
errors = ng.not_equal(ng.argmax(inference_prob, out_axes=[ax.N]),
                      inputs['label'])
eval_loss = ng.cross_entropy_multi(inference_prob,
                                   ng.one_hot(inputs['label'], axis=ax.Y))
eval_outputs = dict(cross_ent_loss=eval_loss, misclass_pct=errors)

# Now bind the computations we are interested in
with closing(ngt.make_transformer()) as transformer:
    train_computation = make_bound_computation(transformer, train_outputs, inputs)
    loss_computation = make_bound_computation(transformer, eval_outputs, inputs)

    cbs = make_default_callbacks(output_file=args.output_file,
def train_mnist_mlp(transformer_name, data_dir=None,
                    rng_seed=12, batch_size=128, train_iter=10, eval_iter=10):
    assert transformer_name in ['cpu', 'hetr']
    assert isinstance(rng_seed, int)

    # Apply this metadata to graph regardless of transformer,
    # but it is ignored for non-HeTr case
    hetr_device_ids = (0, 1)

    # use consistent rng seed between runs
    np.random.seed(rng_seed)

    # Data
    train_data, valid_data = MNIST(path=data_dir).load_data()
    train_set = ArrayIterator(train_data, batch_size,
                              total_iterations=train_iter)
    valid_set = ArrayIterator(valid_data, batch_size)
    inputs = train_set.make_placeholders()
    ax.Y.length = 10

    # Model
    with ng.metadata(device_id=hetr_device_ids, parallel=ax.N):
        seq1 = Sequential([
            Preprocess(functor=lambda x: x / 255.),
            Affine(nout=100, weight_init=GaussianInit(), activation=Rectlin()),
            Affine(axes=ax.Y, weight_init=GaussianInit(), activation=Logistic())
        ])

        train_prob = seq1(inputs['image'])
        train_loss = ng.cross_entropy_binary(
            train_prob, ng.one_hot(inputs['label'], axis=ax.Y))

        optimizer = GradientDescentMomentum(0.1, 0.9)
        batch_cost = ng.sequential(
            [optimizer(train_loss), ng.mean(train_loss, out_axes=())])
        train_outputs = dict(batch_cost=batch_cost)

        with Layer.inference_mode_on():
            inference_prob = seq1(inputs['image'])

        errors = ng.not_equal(ng.argmax(inference_prob, out_axes=[ax.N]),
                              inputs['label'])
        eval_loss = ng.cross_entropy_binary(
            inference_prob, ng.one_hot(inputs['label'], axis=ax.Y))
        eval_outputs = dict(cross_ent_loss=eval_loss, misclass_pct=errors)

    # Runtime
    with closing(
            ngt.make_transformer_factory(transformer_name)()) as transformer:
        train_computation = make_bound_computation(transformer,
                                                   train_outputs, inputs)
        loss_computation = make_bound_computation(transformer,
                                                  eval_outputs, inputs)

        train_costs = list()
        for step in range(train_iter):
            out = train_computation(next(train_set))
            train_costs.append(float(out['batch_cost']))

        ce_loss = list()
        for step in range(eval_iter):
            out = loss_computation(next(valid_set))
            ce_loss.append(np.mean(out['cross_ent_loss']))

    return train_costs, ce_loss
valid_set = ArrayIterator(valid_data, args.batch_size)

inputs = train_set.make_placeholders()

ax.Y.length = 10

######################
# Model specification
seq1 = Sequential([Preprocess(functor=lambda x: x / 255.),
                   Affine(nout=100, weight_init=GaussianInit(),
                          activation=Rectlin()),
                   Affine(axes=ax.Y, weight_init=GaussianInit(),
                          activation=Logistic())])

optimizer = GradientDescentMomentum(0.1, 0.9)
output_prob = seq1.train_outputs(inputs['image'])
errors = ng.not_equal(ng.argmax(output_prob, out_axes=[ax.N]), inputs['label'])
loss = ng.cross_entropy_binary(output_prob,
                               ng.one_hot(inputs['label'], axis=ax.Y))
mean_cost = ng.mean(loss, out_axes=())
updates = optimizer(loss)

train_outputs = dict(batch_cost=mean_cost, updates=updates)
loss_outputs = dict(cross_ent_loss=loss, misclass_pct=errors)

# Now bind the computations we are interested in
transformer = ngt.make_transformer()
train_computation = make_bound_computation(transformer, train_outputs, inputs)
loss_computation = make_bound_computation(transformer, loss_outputs, inputs)

cbs = make_default_callbacks(output_file=args.output_file,
                             frequency=args.iter_interval,
                             train_computation=train_computation,
                             total_iterations=args.num_iterations,
######################
# Input specification
ax.C.length, ax.H.length, ax.W.length = train_set.shapes['image']
ax.D.length = 1
ax.N.length = args.batch_size
ax.Y.length = 10

# placeholders with descriptive names
inputs = dict(image=ng.placeholder([ax.C, ax.H, ax.W, ax.N]),
              label=ng.placeholder([ax.N]))

optimizer = GradientDescentMomentum(0.01, 0.9)
output_prob = seq1.train_outputs(inputs['image'])
errors = ng.not_equal(ng.argmax(output_prob, out_axes=[ax.N]), inputs['label'])
loss = ng.cross_entropy_multi(output_prob,
                              ng.one_hot(inputs['label'], axis=ax.Y))
mean_cost = ng.mean(loss, out_axes=())
updates = optimizer(loss)

train_outputs = dict(batch_cost=mean_cost, updates=updates)
loss_outputs = dict(cross_ent_loss=loss, misclass_pct=errors)

# Now bind the computations we are interested in
transformer = ngt.make_transformer()
train_computation = make_bound_computation(transformer, train_outputs, inputs)
loss_computation = make_bound_computation(transformer, loss_outputs, inputs)

cbs = make_default_callbacks(output_file=args.output_file,
                             frequency=args.iter_interval,
                             train_computation=train_computation,
                             total_iterations=args.num_iterations,
ax.Y.length = len(tree_bank_data.vocab)


def expand_onehot(x):
    # Assign roles
    x.axes.find_by_short_name('time')[0].add_role(ar.time)
    x.axes.find_by_short_name('time')[0].is_recurrent = True
    return ng.one_hot(x, axis=ax.Y)


# weight initialization
init = UniformInit(low=-0.08, high=0.08)

if args.layer_type == "rnn":
    rlayer = Recurrent(hidden_size, init, activation=Tanh())
elif args.layer_type == "birnn":
    rlayer = BiRNN(hidden_size, init, activation=Tanh(),
                   return_sequence=True, sum_out=True)

if args.use_lut:
    layer_0 = LookupTable(50, 100, init, update=False)
else:
    layer_0 = Preprocess(functor=expand_onehot)

# model initialization
seq1 = Sequential([layer_0,
                   rlayer,
                   Affine(init, activation=Softmax(), bias_init=init,
                          axes=(ax.Y,))])
def train_network(model, train_set, valid_set, batch_size, epochs, log_file):
    '''
    Trains the predefined network. Trains the model and saves the progress in
    the log file that is defined in the arguments.

    model(object): Defines the model in Neon
    train_set(object): Defines the training set
    valid_set(object): Defines the validation set
    args(object): Training arguments
    batch_size(int): Minibatch size
    epochs(int): Number of training epochs
    log_file(string): File name to store training logs for plotting
    '''
    # Form placeholders for inputs to the network
    # Iterations needed for learning rate schedule
    inputs = train_set.make_placeholders(include_iteration=True)

    # Convert labels into one-hot vectors
    one_hot_label = ng.one_hot(inputs['label'], axis=ax.Y)

    learning_rate_policy = {'name': 'schedule',
                            'schedule': list(np.arange(2, epochs, 2)),
                            'gamma': 0.6,
                            'base_lr': 0.001}

    optimizer = GradientDescentMomentum(learning_rate=learning_rate_policy,
                                        momentum_coef=0.9,
                                        wdecay=0.005,
                                        iteration=inputs['iteration'])

    # Define graph for training
    train_prob = model(inputs['video'])
    train_loss = ng.cross_entropy_multi(train_prob, one_hot_label)
    batch_cost = ng.sequential(
        [optimizer(train_loss), ng.mean(train_loss, out_axes=())])

    with closing(ngt.make_transformer()) as transformer:
        # Define graph for calculating validation set error and misclassification rate
        # Use inference mode for validation to avoid dropout in forward pass
        with Layer.inference_mode_on():
            inference_prob = model(inputs['video'])
            errors = ng.not_equal(ng.argmax(inference_prob), inputs['label'])
            eval_loss = ng.cross_entropy_multi(inference_prob, one_hot_label)
            eval_outputs = {'cross_ent_loss': eval_loss, 'misclass': errors}

            eval_computation = make_bound_computation(transformer,
                                                      eval_outputs, inputs)

        train_outputs = {'batch_cost': batch_cost}
        train_computation = make_bound_computation(transformer,
                                                   train_outputs, inputs)

        interval_cost = 0.0

        # Train in epochs
        logs = {'train': [], 'validation': [], 'misclass': []}
        for epoch in trange(epochs, desc='Epochs'):
            # Setup the training bar
            numBatches = train_set.ndata // batch_size
            tpbar = tqdm(unit='batches', ncols=100, total=numBatches,
                         leave=False)

            train_set.reset()
            valid_set.reset()

            train_log = []
            for step, data in enumerate(train_set):
                data = dict(data)
                data['iteration'] = epoch  # learning schedule based on epochs
                output = train_computation(data)
                train_log.append(float(output['batch_cost']))

                tpbar.update(1)
                tpbar.set_description("Training {:0.4f}".format(
                    float(output['batch_cost'])))
                interval_cost += float(output['batch_cost'])
            tqdm.write("Epoch {epch} complete. "
                       "Avg Train Cost {cost:0.4f}".format(
                           epch=epoch,
                           cost=interval_cost / step))
            interval_cost = 0.0
            tpbar.close()
            validation_loss = run_validation(valid_set, eval_computation)
            tqdm.write("Avg losses: {}".format(validation_loss))
            logs['train'].append(train_log)
            logs['validation'].append(validation_loss['cross_ent_loss'])
            logs['misclass'].append(validation_loss['misclass'])

            # Save log data and plot at the end of each epoch
            with open(log_file, 'wb') as f:
                pickle.dump(logs, f)
            plot_logs(logs=logs)
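# Editor's note: my reading of the 'schedule' policy above, assuming
# neon-style semantics: the learning rate starts at base_lr and is
# multiplied by gamma once for each schedule entry already passed.
# scheduled_lr is a hypothetical helper, not part of the example.
import numpy as np

def scheduled_lr(iteration, schedule, gamma, base_lr):
    steps_taken = np.searchsorted(schedule, iteration, side='right')
    return base_lr * gamma ** steps_taken

# with epochs=10, schedule == [2, 4, 6, 8]; at epoch 5 the rate has
# decayed twice: 0.001 * 0.6 ** 2 == 0.00036
print(scheduled_lr(5, [2, 4, 6, 8], gamma=0.6, base_lr=0.001))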
inputs = train_set.make_placeholders()

ax.Y.length = len(tree_bank_data.vocab)


def expand_onehot(x):
    return ng.one_hot(x, axis=ax.Y)


# weight initialization
init = UniformInit(low=-0.08, high=0.08)

if args.use_lut:
    layer_0 = LookupTable(50, 100, init, update=True, pad_idx=0)
else:
    layer_0 = Preprocess(functor=expand_onehot)

if args.layer_type == "rnn":
    rlayer = Recurrent(hidden_size, init, activation=Tanh())
elif args.layer_type == "birnn":
    rlayer = BiRNN(hidden_size, init, activation=Tanh(),
                   return_sequence=True, sum_out=True)

# model initialization
seq1 = Sequential([
    layer_0,
    rlayer,
    Affine(init, activation=Softmax(), bias_init=init, axes=(ax.Y,))
])
def expand_onehot(x):
    """
    Converts integer values to one-hot vectors of the same size as out_axis.
    """
    return ng.one_hot(x, axis=out_axis)
def expand_onehot(x):
    # Assign roles
    x.axes.find_by_short_name('time')[0].add_role(ar.time)
    x.axes.find_by_short_name('time')[0].is_recurrent = True
    return ng.one_hot(x, axis=ax.Y)
no_steps = 75
step = num_iterations // no_steps
schedule = list(np.arange(step, num_iterations, step))
learning_rate_policy = {'name': 'schedule',
                        'schedule': schedule,
                        'gamma': 0.95,
                        'base_lr': 0.01}
optimizer = GradientDescentMomentum(learning_rate=learning_rate_policy,
                                    iteration=inputs['iteration'])

# Define the loss function (Cross entropy loss)
# Note that we convert the integer values of input['y'] to one hot here
fwd_prop = seq1(inputs['X'])
train_loss = ng.cross_entropy_multi(fwd_prop,
                                    ng.one_hot(inputs['y'], axis=out_axis),
                                    usebits=True)

# Train cost computation
batch_cost = ng.sequential(
    [optimizer(train_loss), ng.mean(train_loss, out_axes=())])
train_computation = ng.computation([batch_cost, fwd_prop], "all")
train_outputs = dict(batch_cost=batch_cost)

# Forward prop of evaluation set
# Required for correct functioning of batch norm and dropout layers during inference mode
with Layer.inference_mode_on():
    inference_prop = seq1(inputs['X'])
    eval_loss = ng.cross_entropy_multi(inference_prop,
                                       ng.one_hot(inputs['y'], axis=out_axis),
# Optimizer
# The provided learning policy takes the learning rate as graph input via a placeholder.
# This allows you to control the learning rate based on various factors during training.
learning_rate_policy = {'name': 'provided', 'lr_placeholder': lr_ph}

optimizer = GradientDescentMomentum(learning_rate=learning_rate_policy,
                                    momentum_coef=momentum_coef,
                                    wdecay=wdecay,
                                    nesterov=False,
                                    iteration=input_ops_train['iteration'])
# Make a prediction
prediction = resnet(input_ops_train['image'])
# Calculate loss
train_loss = ng.cross_entropy_multi(
    prediction, ng.one_hot(input_ops_train['label'], axis=ax.Y))
# Average loss over the batch
batch_cost = ng.sequential(
    [optimizer(train_loss), ng.mean(train_loss, out_axes=())])
train_computation = ng.computation(batch_cost, "all")

# Instantiate the Saver object to save weights
weight_saver = Saver()

with ng.metadata(device=device_hetr, device_id=device_id, parallel=ax.N):
    # Inference
    with Layer.inference_mode_on():
        # Doing inference
        inference_prob = resnet(input_ops_valid['image'])
        eval_loss = ng.cross_entropy_multi(
ax.Y.length = 10

resnet = residual_network(args.stage_depth)

learning_rate_policy = {'name': 'schedule',
                        'schedule': [32000, 48000],
                        'gamma': 0.1,
                        'base_lr': 0.1}

optimizer = GradientDescentMomentum(learning_rate=learning_rate_policy,
                                    momentum_coef=0.9,
                                    wdecay=0.0001,
                                    iteration=inputs['iteration'])
label_indices = inputs['label']
train_loss = ng.cross_entropy_multi(resnet(inputs['image']),
                                    ng.one_hot(label_indices, axis=ax.Y))
batch_cost = ng.sequential([optimizer(train_loss),
                            ng.mean(train_loss, out_axes=())])
train_computation = ng.computation(batch_cost, "all")

with Layer.inference_mode_on():
    inference_prob = resnet(inputs['image'])
errors = ng.not_equal(ng.argmax(inference_prob, out_axes=[ax.N]), label_indices)
eval_loss = ng.cross_entropy_multi(inference_prob,
                                   ng.one_hot(label_indices, axis=ax.Y))
eval_loss_names = ['cross_ent_loss', 'misclass']
eval_computation = ng.computation([eval_loss, errors], "all")

# Now bind the computations we are interested in
transformer = ngt.make_transformer()
train_function = transformer.add_computation(train_computation)
eval_function = transformer.add_computation(eval_computation)
                 init, activation=Tanh(), gate_activation=Logistic(),
                 return_sequence=True)

# model initialization
seq1 = Sequential([
    Preprocess(functor=expand_onehot),
    rlayer1,
    rlayer2,
    Affine(init, activation=Softmax(), bias_init=init, axes=(ax.Y,))
])

optimizer = RMSProp(gradient_clip_value=gradient_clip_value)

train_prob = seq1(inputs['inp_txt'])
train_loss = ng.cross_entropy_multi(train_prob,
                                    ng.one_hot(inputs['tgt_txt'], axis=ax.Y),
                                    usebits=True)
batch_cost = ng.sequential(
    [optimizer(train_loss), ng.mean(train_loss, out_axes=())])
train_outputs = dict(batch_cost=batch_cost)

with Layer.inference_mode_on():
    inference_prob = seq1(inputs['inp_txt'])

errors = ng.not_equal(ng.argmax(inference_prob, reduction_axes=[ax.Y]),
                      inputs['tgt_txt'])
errors_last_char = ng.slice_along_axis(errors, ax.REC, time_steps - 1)

eval_loss = ng.cross_entropy_multi(inference_prob,
                                   ng.one_hot(inputs['tgt_txt'], axis=ax.Y),
# download penn treebank
# set shift_target to be False, since it is going to predict the same sequence
tree_bank_data = PTB(path=args.data_dir, shift_target=False)
ptb_data = tree_bank_data.load_data()

train_set = SequentialArrayIterator(ptb_data['train'],
                                    batch_size=args.batch_size,
                                    time_steps=time_steps,
                                    total_iterations=args.num_iterations,
                                    reverse_target=True,
                                    get_prev_target=True)

# weight initialization
init = UniformInit(low=-0.08, high=0.08)

# model initialization
one_hot_enc = Preprocess(functor=lambda x: ng.one_hot(x, axis=ax.Y))
enc = Recurrent(hidden_size, init, activation=Tanh(),
                reset_cells=True, return_sequence=False)
one_hot_dec = Preprocess(functor=lambda x: ng.one_hot(x, axis=ax.Y))
dec = Recurrent(hidden_size, init, activation=Tanh(),
                reset_cells=True, return_sequence=True)
linear = Affine(init, activation=Softmax(), bias_init=init,
                axes=(ax.Y, ax.REC))
    emb_enc_inputs = ng.dot(W_emb, inputs['inp_txt'])

    # decoder input embedding
    emb_dec_input = []
    ax.N.length = args.batch_size
    for i in range(ax.N.length):
        # for each iteration, permute (by true label)
        # encoder input embedding for teacher forcing input to decoder
        emb_enc_input = ng.slice_along_axis(emb_enc_inputs, axis=ax.N, idx=i)
        tmp_axis_1 = ng.make_axis(length=time_steps, name='tmp_axis_1')
        emb_enc_input_tmp = ng.cast_axes(emb_enc_input,
                                         ng.make_axes([hidden_feature_axis,
                                                       tmp_axis_1]))
        perm = ng.slice_along_axis(inputs['tgt_txt'], axis=ax.N, idx=i)
        one_hot_target_tmp = ng.one_hot(perm, axis=tmp_axis_1)
        emb_dec_input.append(ng.dot(emb_enc_input_tmp, one_hot_target_tmp))

    emb_dec_inputs = ng.stack(emb_dec_input, axis=ax.N, pos=1)
    enc_input = emb_enc_inputs
    dec_input = emb_dec_inputs
else:
    enc_input = inputs['inp_txt']
    dec_input = inputs['teacher_txt']

(enc_h_out, enc_c_out) = enc(enc_input, return_cell_state=True)

# compute the last hidden/cell states as decoder's initial states
rec_axis = enc_h_out.axes.recurrent_axis()
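# Editor's note: the ng.dot with a one-hot matrix above is a gather /
# column permutation in disguise; a small NumPy illustration (mine, not
# from the example):
import numpy as np

features, time_steps = 3, 4
emb = np.arange(features * time_steps).reshape(features, time_steps)
targets = np.array([2, 0, 3, 1])
# one_hot[i, j] == 1 iff i == targets[j], so column j of the product is
# column targets[j] of emb
one_hot = np.eye(time_steps)[targets].T
permuted = emb @ one_hot
assert np.array_equal(permuted, emb[:, targets])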
resnet = residual_network(args.stage_depth)

learning_rate_policy = {'name': 'schedule',
                        'schedule': [32000, 48000],
                        'gamma': 0.1,
                        'base_lr': 0.1}

optimizer = GradientDescentMomentum(learning_rate=learning_rate_policy,
                                    momentum_coef=0.9,
                                    wdecay=0.0001,
                                    iteration=inputs['iteration'])
label_indices = inputs['label']
train_loss = ng.cross_entropy_multi(resnet(inputs['image']),
                                    ng.one_hot(label_indices, axis=ax.Y))
batch_cost = ng.sequential(
    [optimizer(train_loss), ng.mean(train_loss, out_axes=())])
train_computation = ng.computation(batch_cost, "all")

with Layer.inference_mode_on():
    inference_prob = resnet(inputs['image'])
errors = ng.not_equal(ng.argmax(inference_prob, out_axes=[ax.N]), label_indices)
eval_loss = ng.cross_entropy_multi(inference_prob,
                                   ng.one_hot(label_indices, axis=ax.Y))
eval_loss_names = ['cross_ent_loss', 'misclass']
eval_computation = ng.computation([eval_loss, errors], "all")

# Now bind the computations we are interested in
def expand_onehot(x):
    return ng.one_hot(x, axis=ax.Y)
logits1 = ng.cast_axes(logits_concat[0], [ax.Y, N])
logits2 = ng.cast_axes(logits_concat[1], [ax.Y, N])

# Compute loss function
label1 = ng.slice_along_axis(
    inputs['answer'],
    axis=inputs['answer'].axes.feature_axes()[0],
    idx=0)
label2 = ng.slice_along_axis(
    inputs['answer'],
    axis=inputs['answer'].axes.feature_axes()[0],
    idx=1)
labels_concat = [label1, label2]
loss1 = ng.cross_entropy_multi(logits1,
                               ng.one_hot(label1, axis=ax.Y),
                               usebits=False)
loss2 = ng.cross_entropy_multi(logits2,
                               ng.one_hot(label2, axis=ax.Y),
                               usebits=False)

# Total Loss
train_loss = loss1 + loss2

# Set optimizer (no learning rate scheduler used)
optimizer = Adam(learning_rate=2e-3)

print('compiling the graph')
# Cost set up
batch_cost = ng.sequential(
    [optimizer(train_loss), ng.mean(train_loss, out_axes=())])
resnet = residual_network(args.stage_depth)

learning_rate_policy = {'name': 'schedule',
                        'schedule': [32000, 48000],
                        'gamma': 0.1,
                        'base_lr': 0.1}

optimizer = GradientDescentMomentum(learning_rate=learning_rate_policy,
                                    momentum_coef=0.9,
                                    wdecay=0.0001,
                                    iteration=inputs['iteration'])
label_indices = inputs['label']
train_loss = ng.cross_entropy_multi(resnet(inputs['image']),
                                    ng.one_hot(label_indices, axis=ax.Y))
batch_cost = ng.sequential(
    [optimizer(train_loss), ng.mean(train_loss, out_axes=())])
train_computation = ng.computation(batch_cost, "all")

with Layer.inference_mode_on():
    inference_prob = resnet(inputs['image'])
errors = ng.not_equal(ng.argmax(inference_prob, out_axes=[ax.N]), label_indices)
eval_loss = ng.cross_entropy_multi(
    inference_prob, ng.one_hot(label_indices, axis=ax.Y))
eval_loss_names = ['cross_ent_loss', 'misclass']
eval_computation = ng.computation([eval_loss, errors], "all")

# Now bind the computations we are interested in
train_set = ArrayIterator(wikimovies.data_dict['train'],
                          batch_size=args.batch_size,
                          total_iterations=num_iterations)
test_set = ArrayIterator(wikimovies.data_dict['test'],
                         batch_size=args.batch_size)
inputs = train_set.make_placeholders()
vocab_axis = ng.make_axis(length=wikimovies.vocab_size, name='vocab_axis')

memn2n = KVMemN2N(num_iterations, args.batch_size, args.emb_size, args.nhops,
                  wikimovies.story_length, wikimovies.memory_size,
                  wikimovies.vocab_size, vocab_axis, args.use_v_luts)

# Compute answer predictions
a_pred, _ = memn2n(inputs)

loss = ng.cross_entropy_multi(a_pred,
                              ng.one_hot(inputs['answer'], axis=vocab_axis),
                              usebits=True)

mean_cost = ng.sum(loss, out_axes=[])

optimizer = Adam(learning_rate=args.lr)

updates = optimizer(loss)

batch_cost = ng.sequential([updates, mean_cost])

# provide outputs for bound computation
train_outputs = dict(batch_cost=batch_cost, train_preds=a_pred)

with Layer.inference_mode_on():
    a_pred_inference, _ = memn2n(inputs)
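# Editor's note: a hedged NumPy reference for cross_entropy_multi with
# usebits=True, assuming usebits switches the logarithm from base e
# (nats) to base 2 (bits):
import numpy as np

def cross_entropy_multi_reference(probs, one_hot, usebits=False):
    log_fn = np.log2 if usebits else np.log
    return -(one_hot * log_fn(probs)).sum(axis=-1)

probs = np.array([[0.7, 0.2, 0.1]])
one_hot = np.array([[1.0, 0.0, 0.0]])
print(cross_entropy_multi_reference(probs, one_hot, usebits=True))  # ~0.515 bits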
def expand_onehot(x):
    # Assign the recurrent role and property to the axis named 'time'
    x.axes.find_by_short_name('time')[0].add_role(ar.time)
    x.axes.find_by_short_name('time')[0].is_recurrent = True
    return ng.one_hot(x, axis=ax.Y)
# Logits
logits1 = ng.cast_axes(logits_concat[0], [ax.Y, N])
logits2 = ng.cast_axes(logits_concat[1], [ax.Y, N])

# Compute loss function
label1 = ng.slice_along_axis(inputs['answer'],
                             axis=inputs['answer'].axes.feature_axes()[0],
                             idx=0)
label2 = ng.slice_along_axis(inputs['answer'],
                             axis=inputs['answer'].axes.feature_axes()[0],
                             idx=1)
labels_concat = [label1, label2]
loss1 = ng.cross_entropy_multi(logits1,
                               ng.one_hot(label1, axis=ax.Y),
                               usebits=False)
loss2 = ng.cross_entropy_multi(logits2,
                               ng.one_hot(label2, axis=ax.Y),
                               usebits=False)

# Total Loss
train_loss = loss1 + loss2

# Set optimizer (no learning rate scheduler used)
optimizer = Adam(learning_rate=2e-3)

print('compiling the graph')
# Cost set up
batch_cost = ng.sequential(