def unfold_from(dynamic_axes_like):
    '''
    Build an unfold recurrence: repeatedly apply ``generator_function`` to the
    previous state and return the resulting sequence of outputs.

    NOTE(review): ``initial_state``, ``generator_function``, ``map_state_function``,
    ``until_predicate``, ``length_increase`` (and ``past_value``) are not
    parameters of this function; they are resolved from an enclosing/module
    scope that is not visible here, presumably the factory that defines this
    closure -- confirm. An alternative signature that takes ``initial_state``
    explicitly, ``unfold_from(initial_state, dynamic_axes_like)``, was left
    commented out in the original at this spot.

    Args:
        dynamic_axes_like: its dynamic axis defines the axis of the output
            sequence. When ``length_increase != 1``, a new axis is derived from
            it via ``sequence.broadcast_as`` + ``sequence.where`` instead.

    Returns:
        The first output of ``generator_function`` wrapped in a
        ``combine`` node named 'unfold_output'. If ``until_predicate`` is not
        None, it is instead filtered by ``sequence.gather`` with a mask computed
        by a ``Recurrence`` (see below).
    '''
    # create a new dynamic axis if a length increase is specified
    out_axis = dynamic_axes_like
    if length_increase != 1:
        factors = sequence.broadcast_as(length_increase, out_axis) # repeat each frame 'length_increase' times, on average
        out_axis = sequence.where(factors)  # note: values are irrelevant; only the newly created axis matters

    # State recurrence: state_fwd is a placeholder that is bound to the new
    # state further down via resolve_to(); prev_state is its delayed value
    # (presumably the previous step's state, with initial_state at the start).
    state_fwd = ForwardDeclaration(name='unfold_state_fwd')
    prev_state = sequence.delay(state_fwd, initial_state=initial_state, name='unfold_prev_state')
    # TODO: must allow multiple variables, just like recurrence, as to allow beam decoding (permutation matrix)
    z = generator_function(prev_state) # returns either (output) or (output, new state)
    output = z.outputs[0]
    new_state = z.outputs[1] if len(z.outputs) > 1 else output # we allow generator to return a single value if it is identical to the new state
    # Apply map_state_function to the new state. It is called unconditionally
    # here, so it must always be callable (presumably it defaults to an identity
    # function in the enclosing scope -- confirm).
    new_state = map_state_function(new_state)
    # Implant the dynamic axis of out_axis (either dynamic_axes_like itself or
    # the axis derived from it above) into the new state.
    from cntk.internal import sanitize_input, typemap
    from ..cntk_py import reconcile_dynamic_axis
    new_state = typemap(reconcile_dynamic_axis)(sanitize_input(new_state), sanitize_input(out_axis))
    new_state = combine([new_state], name='unfold_new_state')
    # Close the loop: the placeholder now refers to the computed new state.
    state_fwd.resolve_to(new_state)

    output = combine([output], name='unfold_output') # BUGBUG: without this, it crashes with bad weak ptr
    # BUGBUG: MUST do this after resolving the recurrence, otherwise also crashes

    # apply until_predicate if given
    if until_predicate is not None:
        # valid_frames starts at 1 and becomes (and stays) 0 once the previous
        # frame's predicate value was 1 -- assuming until_predicate yields 0/1
        # and past_value yields 0 at the first frame (confirm). Frames are thus
        # kept up to and including the first one where the predicate fires.
        valid_frames = Recurrence(lambda h, x: (1-past_value(x)) * h, initial_state=1, name='valid_frames')(until_predicate(output))
        output = sequence.gather(output, valid_frames, name='valid_output')

    return output
def unfold_from(initial_state, dynamic_axes_like):
    '''
    Build an unfold recurrence: repeatedly apply ``generator_function`` to the
    previous state and return the resulting sequence of outputs.

    NOTE(review): ``generator_function``, ``until_predicate`` and
    ``length_increase`` are not parameters of this function; they are resolved
    from an enclosing scope not visible here, presumably the factory that
    defines this closure -- confirm.

    Args:
        initial_state: passed as ``initial_state`` to ``sequence.delay``;
            presumably the state seen as "previous" at the first step.
        dynamic_axes_like: its dynamic axis defines the axis of the output
            sequence. When ``length_increase != 1``, a new axis is derived from
            it via ``sequence.broadcast_as`` + ``sequence.where`` instead.

    Returns:
        The first output of ``generator_function`` wrapped in a
        ``combine`` node named 'unfold_output'. If ``until_predicate`` is not
        None, it is instead filtered by ``sequence.gather`` with a mask computed
        by a ``Recurrence`` (see below).
    '''
    # create a new dynamic axis if a length increase is specified
    out_axis = dynamic_axes_like
    if length_increase != 1:
        factors = sequence.broadcast_as(length_increase, out_axis) # repeat each frame 'length_increase' times, on average
        out_axis = sequence.where(factors)  # note: values are irrelevant; only the newly created axis matters

    # State recurrence: state_fwd is a placeholder that is bound to the new
    # state further down via resolve_to(); prev_state is its delayed value
    # (presumably the previous step's state, with initial_state at the start).
    state_fwd = ForwardDeclaration(name='unfold_state_fwd')
    prev_state = sequence.delay(state_fwd, initial_state=initial_state, name='unfold_prev_state')
    # TODO: must allow multiple variables, just like recurrence, as to allow beam decoding (permutation matrix)
    z = generator_function(prev_state) # returns either (output) or (output, new state)
    output = z.outputs[0]
    new_state = z.outputs[1] if len(z.outputs) > 1 else output # we allow generator to return a single value if it is identical to the new state
    # Implant the dynamic axis of out_axis (either dynamic_axes_like itself or
    # the axis derived from it above) into the new state.
    from cntk.internal import sanitize_input, typemap
    # reconcile_dynamic_axes was previously referenced without being imported
    # here (sanitize_input/typemap exist only to wrap it); import the raw
    # binding locally, mirroring the pattern used for the sibling variant.
    from ..cntk_py import reconcile_dynamic_axes
    new_state = typemap(reconcile_dynamic_axes)(sanitize_input(new_state), sanitize_input(out_axis))
    new_state = combine([new_state], name='unfold_new_state')
    # Close the loop: the placeholder now refers to the computed new state.
    state_fwd.resolve_to(new_state)

    output = combine([output], name='unfold_output') # BUGBUG: without this, it crashes with bad weak ptr
    # BUGBUG: MUST do this after resolving the recurrence, otherwise also crashes

    # apply until_predicate if given
    if until_predicate is not None:
        # valid_frames starts at 1 and becomes (and stays) 0 once the previous
        # frame's predicate value was 1 -- assuming until_predicate yields 0/1
        # and past_value yields 0 at the first frame (confirm). Frames are thus
        # kept up to and including the first one where the predicate fires.
        valid_frames = Recurrence(lambda h, x: (1-sequence.past_value(x)) * h, initial_state=1, name='valid_frames')(until_predicate(output))
        output = sequence.gather(output, valid_frames, name='valid_output')

    return output