Example #1
def make_net(ims, samples, pr, reuse = True, train = True):
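  # Video classification head: pr.net_type selects the backbone ('i3d', 'shift', or
  # 'c3d'); ims holds the video frames and samples the audio waveform fed to the
  # audio-visual backbones. Returns a ut.Struct with logits, softmax probabilities,
  # and the final convolutional features.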
  if pr.net_type == 'i3d':
    import i3d_kinetics
    keep_prob = 0.5 if train else 1.
    if pr.use_i3d_logits:
      with tf.variable_scope('RGB', reuse = reuse):
        net = tfu.normalize_ims(ims)
        i3d_net = i3d_kinetics.InceptionI3d(pr.num_classes, spatial_squeeze = True, final_endpoint = 'Logits')
        logits, _ = i3d_net(net, is_training = train, dropout_keep_prob = keep_prob)
        return ut.Struct(logits = logits, prob = tf.nn.softmax(logits), last_conv = logits)
    else:
      with tf.variable_scope('RGB', reuse = reuse):
        i3d_net = i3d_kinetics.InceptionI3d(pr.num_classes, final_endpoint = 'Mixed_5c')
        net = tfu.normalize_ims(ims)
        net, _ = i3d_net(net, is_training = train, dropout_keep_prob = keep_prob)
      last_conv = net
      net = tf.reduce_mean(last_conv, [1, 2, 3], keep_dims = True)
      with slim.arg_scope(shift.arg_scope(pr, reuse = reuse, train = train)):
        logits = shift.conv3d(
          net, pr.num_classes, [1, 1, 1], scope = 'lb/logits', 
          activation_fn = None, normalizer_fn = None)[:, 0, 0, 0, :]
        return ut.Struct(logits = logits, 
                         prob = tf.nn.softmax(logits), 
                         last_conv = net)

  elif pr.net_type == 'shift':
    with slim.arg_scope(shift.arg_scope(pr, reuse = reuse, train = train)):
      # To train the network without audio, you can set samples to be an all-zero array, and
      # set pr.use_sound = False.
      shift_net = shift.make_net(ims, samples, pr, reuse = reuse, train = train)
      if pr.use_dropout:
        shift_net.last_conv = slim.dropout(shift_net.last_conv, is_training = train)

      net = shift_net.last_conv
      net = tf.reduce_mean(net, [1, 2, 3], keep_dims = True)
      logits = shift.conv3d(
        net, pr.num_classes, [1, 1, 1], scope = 'lb/logits', 
        activation_fn = None, normalizer_fn = None)[:, 0, 0, 0, :]
      return ut.Struct(logits = logits, prob = tf.nn.softmax(logits), last_conv = net)
  elif pr.net_type == 'c3d':
    import c3d
    with slim.arg_scope(shift.arg_scope(pr, reuse = reuse, train = train)):
      net = c3d.make_net(ims, samples, pr, reuse = reuse, train = train)
      net = net.last_conv
      net = tf.reduce_mean(net, [1, 2, 3], keep_dims = True)
      logits = c3d.conv3d(
        net, pr.num_classes, [1, 1, 1], scope = 'lb/logits', 
        activation_fn = None, normalizer_fn = None)[:, 0, 0, 0, :]
      return ut.Struct(logits = logits, prob = tf.nn.softmax(logits), last_conv = net)
  else:
    raise RuntimeError('Unknown pr.net_type: %s' % pr.net_type)
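
A minimal sketch of how this classification head might be wired into a training graph (not from the source): the tensor shapes and the 400-class count are placeholder assumptions, tf is TensorFlow 1.x, and reusing the repo's ut.Struct as a plain attribute container for pr is also an assumption.

import tensorflow as tf

# Hypothetical hyperparameters; only the fields read by make_net above are set here,
# and the backbone builders (shift, i3d, c3d) will typically read additional fields.
pr = ut.Struct(net_type='shift', num_classes=400, use_dropout=True, use_sound=True)

ims = tf.placeholder(tf.float32, [None, 64, 224, 224, 3])   # video frames (assumed shape)
samples = tf.placeholder(tf.float32, [None, 64 * 1024, 2])  # audio waveform (assumed shape)
labels = tf.placeholder(tf.int64, [None])

net = make_net(ims, samples, pr, reuse=False, train=True)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=net.logits)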
Example #2
def make_net(ims,
             sfs,
             spec,
             phase,
             pr,
             reuse=True,
             train=True,
             vid_net_full=None):
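    # Audio-visual separation U-Net: spec/phase hold the mixture spectrogram and phase,
    # ims/sfs the video frames and audio input for the video backbone, and pr the
    # hyperparameter struct; pass vid_net_full to reuse an already-built video network.
    # Returns a ut.Struct with predicted foreground/background spectrograms, phases,
    # and waveforms.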
    if pr.mono:
        print('Using mono!')
        sfs = make_mono(sfs, tile=True)

    if vid_net_full is None:
        if pr.net_style == 'static':
            n = shape(ims, 1)
            # Tile the middle frame across time so the video branch sees a static image.
            ims = tf.cast(ims, tf.float32)
            ims_tile = tf.tile(ims[:, n // 2:n // 2 + 1], (1, n, 1, 1, 1))
            vid_net_full = shift_net.make_net(ims_tile, sfs, pr, None, reuse,
                                              train)
        elif pr.net_style == 'no-im':
            vid_net_full = None
        elif pr.net_style == 'full':
            vid_net_full = shift_net.make_net(ims, sfs, pr, None, reuse, train)
        elif pr.net_style == 'i3d':
            with tf.variable_scope('RGB', reuse=reuse):
                import sep_i3d
                i3d_net = sep_i3d.InceptionI3d(1)
                vid_net_full = ut.Struct(
                    scales=i3d_net(ims, is_training=train))

    with slim.arg_scope(unet_arg_scope(pr, reuse=reuse, train=train)):
        acts = []
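        # acts records each encoder layer's pre-activation output; conv pushes onto it
        # and deconv pops from it to form the U-Net skip connections.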

        def conv(*args, **kwargs):
            out = conv2d(*args, activation_fn=None, **kwargs)
            acts.append(out)
            out = mu.lrelu(out, 0.2)
            return out

        def deconv(*args, **kwargs):
            args = list(args)
            # Pop (or peek at) the matching encoder activation and concatenate it as a
            # skip connection before upsampling.
            skip_layer = acts.pop() if kwargs.pop('do_pop', True) else acts[-1]
            x = args[0]
            if kwargs.pop('concat', True):
                x = tf.concat([x, skip_layer], 3)
            args[0] = tf.nn.relu(x)
            return deconv2d(*args, activation_fn=None, **kwargs)

        def merge_level(net, n):
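            # Fuse video features into the audio stream at encoder level n: spatially
            # pool the video activations, match the audio net's temporal length,
            # tile across the frequency axis, and concatenate along channels.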
            if vid_net_full is None:
                return net
            vid_net = tf.reduce_mean(vid_net_full.scales[n], [2, 3],
                                     keep_dims=True)
            vid_net = vid_net[:, :, 0, :, :]
            s = shape(vid_net)
            if shape(net, 1) != s[1]:
                vid_net = tf.image.resize_images(vid_net, [shape(net, 1), 1])
                print('Video net before merge: %s, after: %s' % (s, shape(vid_net)))
            else:
                print('No need to resize: %s %s' % (s, shape(net)))
            vid_net = tf.tile(vid_net, (1, 1, shape(net, 2), 1))
            net = tf.concat([net, vid_net], 3)
            acts[-1] = net
            return net

        num_freq = shape(spec, 2)
        net = tf.concat([
            ed(normalize_spec(spec, pr), 3),
            ed(normalize_phase(phase, pr), 3)
        ], 3)

        net = net[:, :, :pr.freq_len, :]
        net = conv(net, 64, 4, scope='gen/conv1', stride=[1, 2])
        net = conv(net, 128, 4, scope='gen/conv2', stride=[1, 2])
        net = conv(net, 256, 4, scope='gen/conv3', stride=2)
        net = merge_level(net, 0)
        net = conv(net, 512, 4, scope='gen/conv4', stride=2)
        net = merge_level(net, 1)
        net = conv(net, 512, 4, scope='gen/conv5', stride=2)
        net = merge_level(net, 2)
        net = conv(net, 512, 4, scope='gen/conv6', stride=2)
        net = conv(net, 512, 4, scope='gen/conv7', stride=2)
        net = conv(net, 512, 4, scope='gen/conv8', stride=2)
        net = conv(net, 512, 4, scope='gen/conv9', stride=2)

        net = deconv(net, 512, 4, scope='gen/deconv1', stride=2, concat=False)
        net = deconv(net, 512, 4, scope='gen/deconv2', stride=2)
        net = deconv(net, 512, 4, scope='gen/deconv3', stride=2)
        net = deconv(net, 512, 4, scope='gen/deconv4', stride=2)
        net = deconv(net, 512, 4, scope='gen/deconv5', stride=2)
        net = deconv(net, 256, 4, scope='gen/deconv6', stride=2)
        net = deconv(net, 128, 4, scope='gen/deconv7', stride=2)
        net = deconv(net, 64, 4, scope='gen/deconv8', stride=[1, 2])

        out_fg = deconv(net,
                        2,
                        4,
                        scope='gen/fg',
                        stride=[1, 2],
                        normalizer_fn=None,
                        do_pop=False)
        out_bg = deconv(net,
                        2,
                        4,
                        scope='gen/bg',
                        stride=[1, 2],
                        normalizer_fn=None,
                        do_pop=False)

        def process(out):
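            # Split the 2-channel decoder output into a predicted spectrogram (channel 0)
            # and phase (channel 1), pad the frequency axis back to the input resolution,
            # and invert to a waveform with an ISTFT.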
            pred_spec = out[..., 0]
            pred_spec = tf.tanh(pred_spec)
            pred_spec = unnormalize_spec(pred_spec, pr)

            pred_phase = out[..., 1]
            pred_phase = tf.tanh(pred_phase)
            pred_phase = unnormalize_phase(pred_phase, pr)

            val = soundrep.db_from_amp(0.) if pr.log_spec else 0.
            pred_spec = tf.pad(pred_spec,
                               [(0, 0), (0, 0),
                                (0, num_freq - shape(pred_spec, 2))],
                               constant_values=val)

            if pr.phase_type == 'pred':
                pred_phase = tf.concat([pred_phase, phase[..., -1:]], 2)
            elif pr.phase_type == 'orig':
                pred_phase = phase
            else:
                raise RuntimeError('Unknown pr.phase_type: %s' % pr.phase_type)

            # if ut.hastrue(pr, 'griffin_lim'):
            #   print 'using griffin-lim'
            #   pred_wav = griffin_lim(pred_spec, pred_phase, pr)
            # else:
            pred_wav = istft(pred_spec, pred_phase, pr)
            return pred_spec, pred_phase, pred_wav

        pred_spec_fg, pred_phase_fg, pred_wav_fg = process(out_fg)
        pred_spec_bg, pred_phase_bg, pred_wav_bg = process(out_bg)

        return ut.Struct(
            pred_spec_fg=pred_spec_fg,
            pred_wav_fg=pred_wav_fg,
            pred_phase_fg=pred_phase_fg,
            pred_spec_bg=pred_spec_bg,
            pred_phase_bg=pred_phase_bg,
            pred_wav_bg=pred_wav_bg,
            vid_net=vid_net_full,
        )
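
A hedged usage sketch (not from the source): the pr fields shown are only those read directly in this function, the spectrogram/phase shapes are placeholder assumptions, and helpers such as normalize_spec, unet_arg_scope, and istft may read additional pr fields not listed here.

import tensorflow as tf

# Illustrative hyperparameters; values are assumptions, not the repo's defaults.
pr = ut.Struct(mono=False, net_style='full', freq_len=1024,
               log_spec=True, phase_type='pred')

ims = tf.placeholder(tf.float32, [None, 64, 224, 224, 3])  # video frames (assumed shape)
sfs = tf.placeholder(tf.float32, [None, 64 * 1024, 2])     # audio samples (assumed shape)
spec = tf.placeholder(tf.float32, [None, 256, 1025])       # mixture spectrogram (assumed)
phase = tf.placeholder(tf.float32, [None, 256, 1025])      # mixture phase (assumed)

out = make_net(ims, sfs, spec, phase, pr, reuse=False, train=True)
# out.pred_wav_fg and out.pred_wav_bg are the separated foreground/background waveforms.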