def _pooled_logits(conv3d_fn, feats, num_classes):
  """Average `feats` over axes 1-3 and classify with a 1x1x1 convolution.

  Args:
    conv3d_fn: 3D convolution function to use for the logits layer
      (called with scope 'lb/logits', no activation and no normalizer).
    feats: feature tensor; it is averaged over axes [1, 2, 3] with the
      reduced axes kept.
    num_classes: number of output channels of the logits layer.

  Returns:
    ut.Struct with `logits` (the conv output indexed at [:, 0, 0, 0, :]),
    `prob` (softmax of the logits) and `last_conv` (the pooled features).
  """
  pooled = tf.reduce_mean(feats, [1, 2, 3], keep_dims=True)
  logits = conv3d_fn(
      pooled, num_classes, [1, 1, 1], scope='lb/logits',
      activation_fn=None, normalizer_fn=None)[:, 0, 0, 0, :]
  return ut.Struct(logits=logits, prob=tf.nn.softmax(logits), last_conv=pooled)


def make_net(ims, samples, pr, reuse = True, train = True):
  """Build a classification network selected by `pr.net_type`.

  Supported values of `pr.net_type` are 'i3d', 'shift' and 'c3d'.

  Args:
    ims: input images, passed to the selected backbone.
    samples: audio samples, passed to the 'shift' and 'c3d' backbones
      (unused by the 'i3d' branch).
    pr: parameter object; this function reads `net_type`, `num_classes`,
      and, depending on the branch, `use_i3d_logits` and `use_dropout`.
    reuse: forwarded to the variable/arg scopes.
    train: whether to build the graph in training mode. For 'i3d' it also
      selects the dropout keep probability (0.5 when training, else 1).

  Returns:
    ut.Struct with fields `logits`, `prob` (softmax of `logits`) and
    `last_conv`. With `pr.use_i3d_logits` set, `last_conv` is the logits
    tensor itself; in every other branch it is the feature map averaged
    over axes [1, 2, 3].

  Raises:
    RuntimeError: if `pr.net_type` is not one of the supported values.
  """
  # NOTE(review): the block nesting below was reconstructed from source whose
  # indentation had been lost -- confirm which statements belong inside each
  # `with` scope.
  if pr.net_type == 'i3d':
    import i3d_kinetics
    keep_prob = 0.5 if train else 1.

    if pr.use_i3d_logits:
      # Use I3D's own 'Logits' endpoint as the classifier.
      with tf.variable_scope('RGB', reuse=reuse):
        net = tfu.normalize_ims(ims)
        i3d_net = i3d_kinetics.InceptionI3d(
            pr.num_classes, spatial_squeeze=True, final_endpoint='Logits')
        logits, _ = i3d_net(
            net, is_training=train, dropout_keep_prob=keep_prob)
        return ut.Struct(
            logits=logits, prob=tf.nn.softmax(logits), last_conv=logits)

    # Otherwise stop at 'Mixed_5c' and attach the shared pooled-logits head.
    with tf.variable_scope('RGB', reuse=reuse):
      i3d_net = i3d_kinetics.InceptionI3d(
          pr.num_classes, final_endpoint='Mixed_5c')
      net = tfu.normalize_ims(ims)
      net, _ = i3d_net(net, is_training=train, dropout_keep_prob=keep_prob)
    with slim.arg_scope(shift.arg_scope(pr, reuse=reuse, train=train)):
      return _pooled_logits(shift.conv3d, net, pr.num_classes)

  elif pr.net_type == 'shift':
    with slim.arg_scope(shift.arg_scope(pr, reuse=reuse, train=train)):
      # To train the network without audio, you can set samples to be an
      # all-zero array, and set pr.use_sound = False.
      shift_net = shift.make_net(ims, samples, pr, reuse=reuse, train=train)
      if pr.use_dropout:
        shift_net.last_conv = slim.dropout(
            shift_net.last_conv, is_training=train)
      return _pooled_logits(shift.conv3d, shift_net.last_conv, pr.num_classes)

  elif pr.net_type == 'c3d':
    import c3d
    # NOTE(review): the original call here omitted `pr`, unlike the two other
    # shift.arg_scope(pr, ...) calls in this function; it is passed now for
    # consistency -- confirm against shift.arg_scope's signature.
    with slim.arg_scope(shift.arg_scope(pr, reuse=reuse, train=train)):
      c3d_net = c3d.make_net(ims, samples, pr, reuse=reuse, train=train)
      return _pooled_logits(c3d.conv3d, c3d_net.last_conv, pr.num_classes)

  else:
    raise RuntimeError('Unknown net_type: %r' % (pr.net_type,))
def make_net(ims, sfs, spec, phase, pr, reuse=True, train=True, vid_net_full=None):
  """Build the u-net that predicts foreground/background spectrograms.

  The input spectrogram and phase are normalized, stacked as two channels and
  passed through a 9-layer strided conv encoder and a mirrored deconv decoder
  with skip connections. When a video network is available, its features are
  concatenated into the encoder after conv3, conv4 and conv5.

  Args:
    ims: video frames. In 'static' mode the single frame at index n // 2 of
      axis 1 is cast to float32 and tiled n times along that axis.
    sfs: audio samples given to the video network (converted with
      `make_mono(sfs, tile=True)` first when `pr.mono` is set).
    spec: input spectrogram; `shape(spec, 2)` is used as the number of
      frequency bins the predictions are padded back to.
    phase: input phase, processed alongside `spec`.
    pr: parameter object; this function reads `mono`, `net_style`,
      `freq_len`, `log_spec` and `phase_type`.
    reuse: forwarded to the variable/arg scopes.
    train: whether to build the graph in training mode.
    vid_net_full: optional pre-built video network exposing `.scales`. When
      None, one is built according to `pr.net_style` ('static', 'full',
      'i3d'), or none is used ('no-im').

  Returns:
    ut.Struct with `pred_spec_fg`, `pred_phase_fg`, `pred_wav_fg`,
    `pred_spec_bg`, `pred_phase_bg`, `pred_wav_bg` and `vid_net`.

  Raises:
    RuntimeError: if `pr.phase_type` is neither 'pred' nor 'orig'.
  """
  # NOTE(review): the block nesting below was reconstructed from source whose
  # indentation had been lost -- confirm which statements belong inside each
  # scope.
  if pr.mono:
    print('Using mono!')
    sfs = make_mono(sfs, tile=True)

  if vid_net_full is None:
    if pr.net_style == 'static':
      # Replace the clip by n copies of one frame. `//` keeps the slice
      # index an integer regardless of the interpreter's `/` semantics.
      n = shape(ims, 1)
      mid = n // 2
      ims = tf.cast(ims, tf.float32)
      ims_tile = tf.tile(ims[:, mid:mid + 1], (1, n, 1, 1, 1))
      vid_net_full = shift_net.make_net(ims_tile, sfs, pr, None, reuse, train)
    elif pr.net_style == 'no-im':
      vid_net_full = None
    elif pr.net_style == 'full':
      vid_net_full = shift_net.make_net(ims, sfs, pr, None, reuse, train)
    elif pr.net_style == 'i3d':
      with tf.variable_scope('RGB', reuse=reuse):
        import sep_i3d
        i3d_net = sep_i3d.InceptionI3d(1)
        vid_net_full = ut.Struct(scales=i3d_net(ims, is_training=train))

  with slim.arg_scope(unet_arg_scope(pr, reuse=reuse, train=train)):
    # Stack of pre-activation encoder outputs, consumed as skip connections
    # by the decoder in reverse order.
    acts = []

    def conv(*args, **kwargs):
      """conv2d without activation; record the output, then leaky-ReLU it."""
      out = conv2d(*args, activation_fn=None, **kwargs)
      acts.append(out)
      return mu.lrelu(out, 0.2)

    def deconv(*args, **kwargs):
      """ReLU + deconv2d, optionally concatenating a skip connection.

      Extra keyword arguments (removed before calling deconv2d):
        do_pop: pop the skip layer from `acts` (default) or just peek at it.
        concat: concatenate the skip layer on axis 3 (default) or ignore it.
      """
      args = list(args)
      do_pop = kwargs.pop('do_pop', True)
      do_concat = kwargs.pop('concat', True)
      skip_layer = acts.pop() if do_pop else acts[-1]
      x = args[0]
      if do_concat:
        x = tf.concat([x, skip_layer], 3)
      args[0] = tf.nn.relu(x)
      return deconv2d(*args, activation_fn=None, **kwargs)

    def merge_level(net, n):
      """Concatenate video features `scales[n]` onto `net` along axis 3."""
      if vid_net_full is None:
        return net
      # Average the video features over axes 2 and 3, then drop axis 2.
      vid_net = tf.reduce_mean(vid_net_full.scales[n], [2, 3], keep_dims=True)
      vid_net = vid_net[:, :, 0, :, :]
      s = shape(vid_net)
      if shape(net, 1) != s[1]:
        # Match the length of axis 1 to that of the audio features.
        vid_net = tf.image.resize_images(vid_net, [shape(net, 1), 1])
        print('Video net before merge: %s After: %s' % (s, shape(vid_net)))
      else:
        print('No need to resize: %s %s' % (s, shape(net)))
      vid_net = tf.tile(vid_net, (1, 1, shape(net, 2), 1))
      net = tf.concat([net, vid_net], 3)
      # The merged tensor also becomes the skip connection for this level.
      acts[-1] = net
      return net

    num_freq = shape(spec, 2)
    net = tf.concat([
        ed(normalize_spec(spec, pr), 3),
        ed(normalize_phase(phase, pr), 3),
    ], 3)
    net = net[:, :, :pr.freq_len, :]

    # Encoder.
    net = conv(net, 64, 4, scope='gen/conv1', stride=[1, 2])
    net = conv(net, 128, 4, scope='gen/conv2', stride=[1, 2])
    net = conv(net, 256, 4, scope='gen/conv3', stride=2)
    net = merge_level(net, 0)
    net = conv(net, 512, 4, scope='gen/conv4', stride=2)
    net = merge_level(net, 1)
    net = conv(net, 512, 4, scope='gen/conv5', stride=2)
    net = merge_level(net, 2)
    net = conv(net, 512, 4, scope='gen/conv6', stride=2)
    net = conv(net, 512, 4, scope='gen/conv7', stride=2)
    net = conv(net, 512, 4, scope='gen/conv8', stride=2)
    net = conv(net, 512, 4, scope='gen/conv9', stride=2)

    # Decoder. The first deconv pops conv9's activation without using it.
    net = deconv(net, 512, 4, scope='gen/deconv1', stride=2, concat=False)
    net = deconv(net, 512, 4, scope='gen/deconv2', stride=2)
    net = deconv(net, 512, 4, scope='gen/deconv3', stride=2)
    net = deconv(net, 512, 4, scope='gen/deconv4', stride=2)
    net = deconv(net, 512, 4, scope='gen/deconv5', stride=2)
    net = deconv(net, 256, 4, scope='gen/deconv6', stride=2)
    net = deconv(net, 128, 4, scope='gen/deconv7', stride=2)
    net = deconv(net, 64, 4, scope='gen/deconv8', stride=[1, 2])

    # Both output heads share the last remaining skip layer, so neither of
    # them pops it.
    out_fg = deconv(net, 2, 4, scope='gen/fg', stride=[1, 2],
                    normalizer_fn=None, do_pop=False)
    out_bg = deconv(net, 2, 4, scope='gen/bg', stride=[1, 2],
                    normalizer_fn=None, do_pop=False)

    def process(out):
      """Turn a 2-channel head output into (spec, phase, waveform)."""
      pred_spec = unnormalize_spec(tf.tanh(out[..., 0]), pr)
      pred_phase = unnormalize_phase(tf.tanh(out[..., 1]), pr)

      # Pad axis 2 of the spectrogram back up to num_freq entries.
      val = soundrep.db_from_amp(0.) if pr.log_spec else 0.
      pred_spec = tf.pad(
          pred_spec,
          [(0, 0), (0, 0), (0, num_freq - shape(pred_spec, 2))],
          constant_values=val)

      if pr.phase_type == 'pred':
        # Append the last entry of the input phase along axis 2.
        pred_phase = tf.concat([pred_phase, phase[..., -1:]], 2)
      elif pr.phase_type == 'orig':
        pred_phase = phase
      else:
        raise RuntimeError('Unknown phase_type: %r' % (pr.phase_type,))

      pred_wav = istft(pred_spec, pred_phase, pr)
      return pred_spec, pred_phase, pred_wav

    pred_spec_fg, pred_phase_fg, pred_wav_fg = process(out_fg)
    pred_spec_bg, pred_phase_bg, pred_wav_bg = process(out_bg)

    return ut.Struct(
        pred_spec_fg=pred_spec_fg,
        pred_wav_fg=pred_wav_fg,
        pred_phase_fg=pred_phase_fg,
        pred_spec_bg=pred_spec_bg,
        pred_phase_bg=pred_phase_bg,
        pred_wav_bg=pred_wav_bg,
        vid_net=vid_net_full,
    )