コード例 #1
0
    def hybrid_forward(self, F, data):
        """Embed the token ids, run them through the RNN, activate, and
        project to the output layer.

        Parameters
        ----------
        F : module
            The mxnet symbol/ndarray namespace supplied by Gluon's hybrid
            dispatch (unused directly here).
        data : NDArray/Symbol
            Token-id batch of shape (16, 1, 64) — TODO confirm whether the
            leading 16 is batch size or sequence length for other callers.

        Returns
        -------
        The output-layer activations, batch-major: (batch, length, width).
        """
        # data.shape is a triple of (16, 1, 64). Need to eliminate that redundant second dimension and transpose it
        # before attaching the embeddings
        data = squeeze(data)
        data = data.T
        embedded = self.embedding(data)

        x = self.rnn(embedded)
        # BUG FIX: the original read 'x - self.activation1(x)', which computed
        # the activation and threw the result away; assign it instead.
        x = self.activation1(x)

        # Swap the first and second axes to bring it from (length, batch size, width) to (batch size, length, width),
        # before passing it to the outer layer (only recurrent layers use the first ordering).
        x = swapaxes(x, 0, 1)

        x = self.out(x)

        return x
コード例 #2
0
def squeeze(input, dim):
    """Return *input* with the size-1 axis *dim* removed (via nd.squeeze)."""
    squeezed = nd.squeeze(input, axis=dim)
    return squeezed
コード例 #3
0
ファイル: train.py プロジェクト: zhangjf2018/gluon-cv
def train(args, dataset, generator, discriminator):
    """Progressive-growing GAN training loop (Gluon).

    Starts at resolution ``args.init_size`` and doubles the resolution each
    time ``args.phase * 2`` samples have been consumed, fading the new
    resolution in through ``alpha``. Only the 'wgan' loss is implemented.

    Relies on module-level state defined elsewhere in this file:
    ``g_optimizer``, ``d_optimizer``, ``g_running``, ``context``,
    ``code_size``, ``n_critic`` and ``logger`` — TODO confirm.

    Parameters
    ----------
    args : Namespace
        Training hyper-parameters (init_size, max_size, phase, batch,
        lr dicts, loss name, mixing flag, output/checkpoint dirs, ...).
    dataset : Dataset
        Image dataset consumed through ``sample_data``.
    generator, discriminator : gluon.Block
        The networks being trained.
    """

    step = int(math.log2(args.init_size)) - 2
    resolution = 4 * 2**step
    loader = sample_data(dataset, args.batch.get(resolution,
                                                 args.batch_default),
                         resolution)
    data_loader = iter(loader)

    adjust_lr(g_optimizer, args.lr.get(resolution, args.lr_default))
    adjust_lr(d_optimizer, args.lr.get(resolution, args.lr_default))

    alpha = 0
    used_sample = 0

    max_step = int(math.log2(args.max_size)) - 2
    final_progress = False

    requires_grad(generator, False)
    requires_grad(discriminator, True)

    # BUG FIX: the status line / logger at the bottom of the loop reads these
    # before they are first assigned (they are only updated every 10th /
    # every n_critic-th iteration), which raised NameError on iteration 0.
    gen_loss_val = 0.0
    disc_loss_val = 0.0
    d_real_val = 0.0
    d_fake_val = 0.0

    pbar = tqdm(range(200_000))

    for i in pbar:

        # Fade-in coefficient for the newly added resolution block.
        alpha = min(1, 1 / args.phase * (used_sample + 1))

        if (resolution == args.init_size
                and args.ckpt is None) or final_progress:
            alpha = 1

        # Grow to the next resolution once enough samples were consumed.
        if used_sample > args.phase * 2:
            used_sample = 0
            step += 1

            if step > max_step:
                step = max_step
                final_progress = True
                ckpt_step = step + 1

            else:
                alpha = 0
                ckpt_step = step

            resolution = 4 * 2**step

            loader = sample_data(
                dataset, args.batch.get(resolution, args.batch_default),
                resolution)
            data_loader = iter(loader)

            # Checkpoint all three networks at the resolution boundary.
            generator.save_parameters(
                osp.join(args.ckpt_dir, f'generator_step-{ckpt_step}.params'))
            discriminator.save_parameters(
                osp.join(args.ckpt_dir,
                         f'discriminator_step-{ckpt_step}.params'))
            g_running.save_parameters(
                osp.join(args.ckpt_dir, f'g_running_step-{ckpt_step}.params'))

            adjust_lr(g_optimizer, args.lr.get(resolution, args.lr_default))
            adjust_lr(d_optimizer, args.lr.get(resolution, args.lr_default))

        try:
            real_image = next(data_loader)

        except (OSError, StopIteration):
            # Restart the loader when it is exhausted.
            data_loader = iter(loader)
            real_image = next(data_loader)

        used_sample += real_image.shape[0]
        b_size = real_image.shape[0]
        real_image_list = gluon.utils.split_and_load(real_image,
                                                     ctx_list=context,
                                                     batch_axis=0)

        # Style mixing: with probability 0.9 feed two latents per pass.
        if args.mixing and random.random() < 0.9:
            gen_in11, gen_in12, gen_in21, gen_in22 = nd.random.randn(
                4, b_size, code_size).split(4, 0)
            gen_in1 = [
                nd.squeeze(gen_in11, axis=0),
                nd.squeeze(gen_in12, axis=0)
            ]
            gen_in2 = [
                nd.squeeze(gen_in21, axis=0),
                nd.squeeze(gen_in22, axis=0)
            ]

        else:
            gen_in1, gen_in2 = nd.random.randn(2, b_size,
                                               code_size).split(2, 0)
            gen_in1 = nd.squeeze(gen_in1, axis=0)
            gen_in2 = nd.squeeze(gen_in2, axis=0)

        gen_in1_list = gluon.utils.split_and_load(gen_in1,
                                                  ctx_list=context,
                                                  batch_axis=0)
        gen_in2_list = gluon.utils.split_and_load(gen_in2,
                                                  ctx_list=context,
                                                  batch_axis=0)

        # ---------------- discriminator update ----------------
        if args.loss == 'wgan':
            fake_predict_list = []
            real_predict_list = []
            D_loss_list = []
            with autograd.record():
                for _, (rl_image,
                        g1) in enumerate(zip(real_image_list, gen_in1_list)):
                    real_predict = discriminator(rl_image, step, alpha)
                    real_predict = -real_predict.mean()
                    real_predict_list.append(real_predict)

                    fake_image = generator(g1, step, alpha)
                    fake_predict = discriminator(fake_image.detach(), step,
                                                 alpha)
                    fake_predict = fake_predict.mean()
                    fake_predict_list.append(fake_predict)

                    D_loss_list.append(real_predict + fake_predict)

            # BUG FIX: the original called autograd.backward(loss_list),
            # but no 'loss_list' exists — the accumulated losses are in
            # D_loss_list.
            autograd.backward(D_loss_list)

        elif args.loss == 'r1':
            # Not able to implement r1 loss
            raise Exception(
                'r1 loss has not been implemented, please use wgan loss')
        else:
            raise Exception('Not valid loss, please use wgan loss')

        if i % 10 == 0:
            real_predict_val = [i.asnumpy() for i in real_predict_list]
            fake_predict_val = [i.asnumpy() for i in fake_predict_list]
            d_real_val = np.concatenate(real_predict_val).mean()
            d_fake_val = np.concatenate(fake_predict_val).mean()
            disc_loss_val = d_real_val + d_fake_val

        d_optimizer.step(b_size, ignore_stale_grad=True)

        # ---------------- generator update (every n_critic steps) ----------------
        if (i + 1) % n_critic == 0:

            requires_grad(generator, True)
            requires_grad(discriminator, False)

            # BUG FIX: this branch checked args.loss == 'wgan-gp', but the
            # discriminator branch above only accepts 'wgan', so every
            # generator step raised 'Not valid loss'. Use the same name.
            if args.loss == 'wgan':
                predict_list = []
                with autograd.record():
                    for _, g2 in enumerate(gen_in2_list):
                        fake_image = generator(g2, step, alpha)
                        predict = discriminator(fake_image, step, alpha)
                        predict = -predict.mean()
                        predict_list.append(predict)
                autograd.backward(predict_list)
            elif args.loss == 'r1':
                # Not able to implement r1 loss
                raise Exception(
                    'r1 loss has not been implemented, please use wgan loss')
            else:
                raise Exception('Not valid loss, please use wgan loss')

            if i % 10 == 0:
                predict_val = [i.asnumpy() for i in predict_list]
                gen_loss_val = np.concatenate(predict_val).mean()

            g_optimizer.step(b_size, ignore_stale_grad=True)

            # Exponential moving average of generator weights.
            accumulate(g_running, generator)

            requires_grad(generator, False)
            requires_grad(discriminator, True)

        # Periodically dump a grid of samples from the EMA generator.
        if (i + 1) % 100 == 0:
            images = []

            gen_i, gen_j = args.gen_sample.get(resolution, (10, 5))

            for _ in range(gen_i):
                results = g_running(
                    nd.random.randn(gen_j, code_size, ctx=mx.gpu(0)), step,
                    alpha)
                for r in results:
                    images.append(r)

            plot_images(images, osp.join(args.out,
                                         f'{str(i + 1).zfill(6)}.png'), gen_i,
                        gen_j)

        if (i + 1) % 1000 == 0:
            generator.save_parameters(
                osp.join(args.ckpt_dir, f'g-{str(i + 1).zfill(6)}.params'))
            discriminator.save_parameters(
                osp.join(args.ckpt_dir, f'd-{str(i + 1).zfill(6)}.params'))
            g_running.save_parameters(
                osp.join(args.ckpt_dir,
                         f'g_running-{str(i + 1).zfill(6)}.params'))

        state_msg = (
            f'Size: {4 * 2 ** step}; G: {gen_loss_val:.1f}; D: {disc_loss_val:.1f};'
            f'D_real: {d_real_val:.1f}; D_fake: {d_fake_val:.1f}; Alpha: {alpha:.4f}'
        )

        # BUG FIX: the original used ':1f' (field width 1) for D_real/D_fake;
        # ':.1f' (one decimal place) matches the progress-bar message above.
        logger.info(
            f'Size: {4 * 2 ** step}; G: {gen_loss_val:.1f}; D: {disc_loss_val:.1f}\
            D_real: {d_real_val:.1f}; D_fake: {d_fake_val:.1f}; Alpha: {alpha:.4f}'
        )

        pbar.set_description(state_msg)
コード例 #4
0
    def forward(self, input_vec, loss=None):
        """Compute per-slot action probabilities and a scalar state value.

        Splits ``input_vec`` column-wise into one segment per slot plus a
        concatenated 'global' segment, runs them through the per-slot input
        transforms and the shared hidden layers, and — in the dueling,
        non-shared-last-layer configuration — returns a softmax over all
        slot logits together with a squeezed value-head output.

        Parameters
        ----------
        input_vec : NDArray
            Batch of flat state vectors, shape (batch, self.input_dimension).
        loss : optional
            Forwarded to the transform layers; its semantics are defined by
            those layers and not visible here.

        Returns
        -------
        (batch_slotv_prob, batch_value)
            NOTE(review): both names are only assigned when
            ``self.share_last_layer is False`` and ``self.use_dueling`` is
            True; any other flag combination would raise NameError at the
            return — confirm those flags are fixed for this class.
        """
        # print('************* ' + str(input_vec.shape[1]) + ' *************')
        # print('############# ' + str(input_vec.shape) + ' #############')
        assert input_vec.shape[1] == self.input_dimension

        # get inputs for every slot(including global)
        # Each slot owns a [begin, end) column range in self.slot_dimension.
        inputs = {}
        for slot in self.slots:
            inputs[slot] = input_vec[:, self.slot_dimension[slot][0]:self.
                                     slot_dimension[slot][1]]
        # 'global' features may span several disjoint column ranges; stitch
        # them into a single segment.
        input_global = []
        for seg in self.global_dimension:
            input_global.append(input_vec[:, seg[0]:seg[1]])
        inputs['global'] = nd.concat(*input_global, dim=1)

        layer = []
        # inputs -> first_hidden_layer
        if (not self.sort_input_vec) and self.state_feature != 'dip':
            # One independent input transform per slot, plus one for 'global'.
            layer.append([])
            for slot in self.slots:
                layer[0].append(self.input_trans[slot](inputs[slot]))
            layer[0].append(self.input_trans['global'](inputs['global']))
        elif self.state_feature == 'dip':
            # DIP features: one shared transform over the ordered segment list.
            sorted_inputs = []
            for slot in self.slots:
                sorted_inputs.append(inputs[slot])
            sorted_inputs.append(inputs['global'])
            layer.append(self.input_trans(sorted_inputs, loss))
        elif self.sort_input_vec:
            # Sort each slot's value features descending, then pad with zeros
            # or truncate to exactly 20 columns; the final two columns of the
            # segment are kept as-is and re-appended.
            sorted_inputs = []
            for slot in self.slots:
                tmp = inputs[slot][:, :-2].sort(is_ascend=False)
                if tmp.shape[1] < 20:
                    tmp = nd.concat(tmp,
                                    nd.zeros((tmp.shape[0], 20 - tmp.shape[1]),
                                             ctx=CTX),
                                    dim=1)
                else:
                    tmp = nd.slice_axis(tmp, axis=1, begin=0, end=20)
                sorted_inputs.append(
                    nd.concat(tmp, inputs[slot][:, -2:], dim=1))
            sorted_inputs.append(inputs['global'])
            layer.append(self.input_trans(sorted_inputs, loss))

        # hidden_layers
        for i in range(self.hidden_layers - 1):
            if self.recurrent_mode is False:
                # equal to 'layer.append(self.ma_trans[i](layer[-1], loss))'
                layer.append(self.ma_trans[i](layer[i], loss))
            else:
                # Recurrent mode reuses one shared transform across depths.
                layer.append(self.ma_trans(layer[i], loss))

        if self.share_last_layer is False:
            # dropout of last hidden layer
            for j in range(len(self.slots)):
                layer[-1][j] = self.local_out_drop_op(layer[-1][j])
            layer[-1][-1] = self.global_out_drop_op(layer[-1][-1])

            # last_hidden_layer -> outputs
            outputs = []
            slotv_probs = []
            tmp_ave = nd.zeros_like(layer[-1][0])
            for i in range(len(self.slots) + 1):
                if self.use_dueling is False:
                    # NOTE(review): 'outputs' is filled here but never used;
                    # in this branch the return below references names that
                    # were never assigned. Presumably use_dueling is always
                    # True for this network — confirm.
                    outputs.append(self.output_trans[i](layer[-1][i]))
                else:
                    if i < len(self.slots):
                        cur_slot_prob = self.output_trans_local_slotP(
                            layer[-1][i])
                        # Accumulate slot features; fed to the value head
                        # together with the global features on the last pass.
                        tmp_ave = tmp_ave + layer[-1][i]
                    else:
                        cur_slot_prob = self.output_trans_global_slotP(
                            layer[-1][i])
                        slots_concat = nd.concat(*[tmp_ave, layer[-1][i]],
                                                 dim=1)

                    slotv_probs.append(cur_slot_prob)

            # Softmax over the concatenated per-slot logits; squeeze the
            # value head down to one scalar per batch element.
            batch_slotv_prob = nd.softmax(nd.concat(*slotv_probs, dim=1))
            batch_value = nd.squeeze(self.output_trans_value(slots_concat))

            # print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@')
            # print(batch_slotv_prob)
            # print(batch_slotv_prob.shape)
            # print(batch_value)
            # print(batch_value.shape)
            # exit(0)

        return batch_slotv_prob, batch_value
コード例 #5
0
    def forward(self, input_vec, pop_art_hyper=None, loss=None, training=True):
        """Attention-weighted variant of the slot network with Pop-Art value
        rescaling.

        Like the plain version: slices ``input_vec`` into per-slot segments
        plus a 'global' segment, runs input + hidden transforms, then builds
        a softmax over per-slot value logits gated by per-slot probabilities.
        Slot features feeding the value head are re-weighted by a learned
        attention over (slot, global) feature pairs. When ``pop_art_hyper``
        is given, the value head is rescaled with the supplied
        (sigma, sigma', mu, mu') statistics.

        Parameters
        ----------
        input_vec : NDArray
            Batch of flat state vectors, shape (batch, self.input_dimension).
        pop_art_hyper : tuple or None
            (sigma, sigma_prime, mu, mu_prime) Pop-Art statistics, or None
            for plain affine output via self.private_w / self.private_b.
        loss : optional
            Forwarded to the transform layers.
        training : bool
            Forwarded to layers that behave differently at train time
            (e.g. dropout) — semantics defined by those layers.

        Returns
        -------
        (batch_slotv_prob, batch_value)
            NOTE(review): only assigned when ``self.share_last_layer is
            False`` and ``self.use_dueling`` is True; other configurations
            would raise NameError at the return — confirm.
        """
        # print('************* ' + str(input_vec.shape[1]) + ' *************')
        # print('############# ' + str(input_vec.shape) + ' #############')
        assert input_vec.shape[1] == self.input_dimension

        # get inputs for every slot(including global)
        inputs = {}
        for slot in self.slots:
            inputs[slot] = input_vec[:, self.slot_dimension[slot][0]:self.slot_dimension[slot][1]]
        input_global = []
        for seg in self.global_dimension:
            input_global.append(input_vec[:, seg[0]:seg[1]])
        inputs['global'] = nd.concat(*input_global, dim=1)

        layer = []
        # inputs -> first_hidden_layer
        if (not self.sort_input_vec) and self.state_feature != 'dip':
            # One independent input transform per slot, plus one for 'global'.
            layer.append([])
            for slot in self.slots:
                layer[0].append(self.input_trans[slot](inputs[slot]))
            layer[0].append(self.input_trans['global'](inputs['global']))
        elif self.state_feature == 'dip':
            # DIP features: one shared transform over the ordered segment list.
            sorted_inputs = []
            for slot in self.slots:
                sorted_inputs.append(inputs[slot])
            sorted_inputs.append(inputs['global'])
            layer.append(self.input_trans.forward(sorted_inputs, loss, training=training))
        elif self.sort_input_vec:
            # Sort each slot's value features descending and pad/truncate to
            # exactly 20 columns, keeping the segment's last two columns.
            sorted_inputs = []
            for slot in self.slots:
                tmp = inputs[slot][:, :-2].sort(is_ascend=False)
                if tmp.shape[1] < 20:
                    tmp = nd.concat(tmp, nd.zeros((tmp.shape[0], 20 - tmp.shape[1]), ctx=CTX), dim=1)
                else:
                    tmp = nd.slice_axis(tmp, axis=1, begin=0, end=20)
                sorted_inputs.append(nd.concat(tmp, inputs[slot][:, -2:], dim=1))
            sorted_inputs.append(inputs['global'])
            layer.append(self.input_trans.forward(sorted_inputs, loss, training=training))

        # hidden_layers
        for i in range(self.hidden_layers - 1):
            if self.recurrent_mode is False:
                # equal to 'layer.append(self.ma_trans[i](layer[-1], loss))'
                layer.append(self.ma_trans[i](layer[i], loss))
            else:
                # Recurrent mode reuses one shared transform across depths.
                layer.append(self.ma_trans(layer[i], loss))

        if self.share_last_layer is False:
            # dropout of last hidden layer
            for j in range(len(self.slots)):
                layer[-1][j] = self.local_out_drop_op(layer[-1][j])
            layer[-1][-1] = self.global_out_drop_op(layer[-1][-1])

            # last_hidden_layer -> outputs
            outputs = []
            slotv_probs = []
            attention = []
            # tmp_ave = nd.zeros_like(layer[-1][0])
            for i in range(len(self.slots) + 1):
                if self.use_dueling is False:
                    # NOTE(review): this branch fills 'outputs' but the
                    # return below references names never assigned here —
                    # presumably use_dueling is always True; confirm.
                    outputs.append(self.output_trans[i](layer[-1][i]))
                else:
                    # Per-slot (or global, on the last pass) value logits.
                    if i < len(self.slots):
                        cur_slotv_prob = self.output_trans_local_valueP.forward(layer[-1][i], training=training)
                        cur_slotv_prob = nd.softmax(cur_slotv_prob)
                    else:
                        cur_slotv_prob = self.output_trans_global_valueP.forward(layer[-1][i], training=training)
                        cur_slotv_prob = nd.softmax(cur_slotv_prob)

                    if i < len(self.slots):
                        cur_slot_prob = self.output_trans_local_slotP.forward(layer[-1][i], training=training)
                        # Attention score from (slot features, global features).
                        attention.append(self.attention_layer.forward(nd.concat(*[layer[-1][i], layer[-1][-1]], dim=1), training=training))
                        # tmp_ave = tmp_ave + layer[-1][i]
                    else:
                        cur_slot_prob = self.output_trans_global_slotP.forward(layer[-1][i], training=training)
                        # Re-weight every slot's features by the softmaxed
                        # attention, then sum them for the value-head input.
                        softmax_attention = nd.softmax(nd.concat(*attention))
                        split_softmax_attention = nd.split(softmax_attention, num_outputs=len(self.slots))
                        tmp_ave = nd.zeros_like(layer[-1][0])
                        for j in range(len(self.slots)):
                            layer[-1][j] = layer[-1][j]*split_softmax_attention[j]
                            tmp_ave = tmp_ave + layer[-1][j]
                        slots_concat = nd.concat(*[tmp_ave, layer[-1][i]], dim=1)

                    # Gate the value distribution by the slot probability.
                    cur_slotv_prob = cur_slotv_prob*cur_slot_prob

                    slotv_probs.append(cur_slotv_prob)

            if pop_art_hyper != None:
                # Pop-Art: rescale the value head with the supplied
                # (sigma, sigma', mu, mu') statistics instead of mutating the
                # private affine parameters (see commented-out set_data calls).
                sigma, sigma_prime, mu, mu_prime = pop_art_hyper
                # self.private_w.set_data((sigma/sigma_prime)*self.private_w.data())
                # self.private_b.set_data((sigma*self.private_b.data() + mu - mu_prime)/sigma_prime)
                batch_slotv_prob = nd.softmax(nd.concat(*slotv_probs, dim=1))
                batch_value = nd.squeeze((sigma/sigma_prime)*self.private_w.data()*(self.output_trans_value.forward(slots_concat, training=training))+((sigma*self.private_b.data() + mu - mu_prime)/sigma_prime))
            else:
                # Plain affine output: w * value + b.
                batch_slotv_prob = nd.softmax(nd.concat(*slotv_probs, dim=1))
                batch_value = nd.squeeze(self.private_w.data()*(self.output_trans_value.forward(slots_concat, training=training))+self.private_b.data())


            # print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@')
            # print(batch_slotv_prob)
            # print(batch_slotv_prob.shape)
            # print(batch_value)
            # print(batch_value.shape)
            # exit(0)
            
        return batch_slotv_prob, batch_value
コード例 #6
0
def validate(net, val_data, ctx, eval_metric, args):
    """Test on validation dataset.

    Runs the detector over ``val_data``, feeds detections and ground truth
    into ``eval_metric``, and additionally measures how well the RPN
    proposals (``roi``) cover the ground-truth boxes via greedy IoU
    matching (printed as "RPN GT Recall").

    Parameters
    ----------
    net : gluon Block
        Detector returning (ids, scores, bboxes, roi) per image.
    val_data : iterable
        Validation batches; each sample unpacks to (image, label, im_scale).
        Labels pack boxes in columns 0-3, class id in column 4, and an
        optional 'difficult' flag in column 5.
    ctx : list of Context
        Devices to split each batch across.
    eval_metric : metric object
        Accumulates detection results; its get() is returned.
    args : Namespace
        Uses disable_hybridization and static_alloc flags.

    Returns
    -------
    The (names, values) pair from ``eval_metric.get()``.
    """
    clipper = gcv.nn.bbox.BBoxClipToImage()
    eval_metric.reset()
    if not args.disable_hybridization:
        # input format is differnet than training, thus rehybridization is needed.
        net.hybridize(static_alloc=args.static_alloc)
    rpn_gt_recalls = []
    for batch in val_data:
        batch = split_and_load(batch, ctx_list=ctx)
        det_bboxes = []
        det_ids = []
        det_scores = []
        gt_bboxes = []
        gt_ids = []
        gt_difficults = []
        for x, y, im_scale in zip(*batch):
            # get prediction results
            ids, scores, bboxes, roi = net(x)
            det_ids.append(ids)
            det_scores.append(scores)
            # clip to image size
            det_bboxes.append(clipper(bboxes, x))
            # rescale to original resolution
            im_scale = im_scale.reshape((-1)).asscalar()
            det_bboxes[-1] *= im_scale
            # split ground truths
            gt_ids.append(y.slice_axis(axis=-1, begin=4, end=5))
            gt_bboxes.append(y.slice_axis(axis=-1, begin=0, end=4))
            gt_bboxes[-1] *= im_scale
            gt_difficults.append(
                y.slice_axis(axis=-1, begin=5, end=6
                             ) if y.shape[-1] > 5 else None)

            gt_label = y[:, :, 4:5]
            gt_box = y[:, :, :4]
            for i in range(gt_label.shape[0]):
                _gt_label = nd.squeeze(gt_label[i])
                match_mask = nd.zeros_like(_gt_label)
                # If both boxes have zero area, the IoU is 0.
                iou = nd.contrib.box_iou(roi[i], gt_box[i], format='corner')
                num_raw = iou.shape[1]
                # Greedily assign a proposal to each gt box, highest IoU
                # first; a matched row/column is zeroed so each proposal and
                # each gt box is used at most once.
                # Reference: http://zh.d2l.ai/chapter_computer-vision/anchor.html#%E6%A0%87%E6%B3%A8%E8%AE%AD%E7%BB%83%E9%9B%86%E7%9A%84%E9%94%9A%E6%A1%86
                for _ in range(_gt_label.shape[0]):
                    _iou = iou.reshape(-1)
                    # NOTE: 'max' shadows the builtin here; scoped to this loop.
                    max = nd.max(_iou, axis=0)
                    if max < 0.5:
                        break
                    pos = nd.argmax(_iou, axis=0)
                    # Recover (row, col) of the best pair from the flat index.
                    raw = (pos / num_raw).astype(np.int64)
                    col = pos % num_raw
                    iou[raw, :] = 0
                    iou[:, col] = 0
                    match_mask[col] = 1
                # Ignore padded gt entries (label == -1) when computing recall.
                match_mask = nd.contrib.boolean_mask(match_mask,
                                                     _gt_label != -1)
                rpn_gt_recalls.append(nd.mean(match_mask).asscalar())

        # update metric
        for det_bbox, det_id, det_score, gt_bbox, gt_id, gt_diff in zip(
                det_bboxes, det_ids, det_scores, gt_bboxes, gt_ids,
                gt_difficults):
            eval_metric.update(det_bbox, det_id, det_score, gt_bbox, gt_id,
                               gt_diff)
    rpn_gt_recall = np.mean(rpn_gt_recalls)
    print("RPN GT Recall", rpn_gt_recall)
    return eval_metric.get()
コード例 #7
0
def verify_broadcast_like_dynamic(xshp, wshp, lhs_axes, rhs_axes):
    """Verify a dynamic-shape-safe rewrite of ``nd.broadcast_like``.

    Computes the reference result with ``nd.broadcast_like`` and a rewritten
    equivalent built from tile / slice / repeat / transpose primitives that
    avoids broadcasting over the batch axis (axis 0), then compares shapes
    and norms of the two results.

    Parameters
    ----------
    xshp, wshp : tuple of int
        Shapes of the lhs operand ``x`` and the shape-provider ``w``.
    lhs_axes, rhs_axes : sequence of int or None
        Axes to broadcast on each operand; both None means full
        broadcast_like (requires equal ranks).
    """
    x_np = np.random.uniform(size=xshp)
    w_np = np.random.uniform(size=wshp)
    x = nd.array(x_np)
    w = nd.array(w_np)

    # org op
    y = nd.broadcast_like(x, w,
        lhs_axes=lhs_axes, rhs_axes=rhs_axes)
    print(y.shape)

    # rewrite op
    xndims, wndims = len(xshp), len(wshp)
    if lhs_axes is None or rhs_axes is None:
        assert xndims == wndims and lhs_axes is None \
            and rhs_axes is None
        z = _broadcast_like(x, w)
    else:
        lhs_axes, lndims = list(lhs_axes), len(lhs_axes)
        rhs_axes, rndims = list(rhs_axes), len(rhs_axes)
        assert lndims == rndims > 0

        # Normalize negative axes and validate ranges.
        lhs_axes = tuple([v+xndims if v<0 else v for v in lhs_axes])
        assert all([0<=v<xndims for v in list(lhs_axes)])

        rhs_axes = tuple([v+wndims if v<0 else v for v in rhs_axes])
        assert all([0<=v<wndims for v in list(rhs_axes)])

        # Broadcasting is only legal from size-1 dimensions on the lhs.
        assert all([xshp[lhs_axes[i]] == 1 for i in range(lndims)])

        batch_axes = [0]
        flg = all([batch_axis not in rhs_axes \
            for batch_axis in batch_axes])
        if flg:
            # No batch axis involved: a static tile suffices.
            cnts = {v: wshp[rhs_axes[i]] \
                for i, v in enumerate(lhs_axes)}
            reps = tuple([cnts[v] if v in lhs_axes else 1 \
                for v in range(xndims)])
            z = nd.tile(x, reps=reps)
        else:
            # Batch axis is a broadcast source: reshape w into an operand of
            # x's rank so _broadcast_like can pick up the dynamic dimension.
            axis_map = {}
            for i, v in enumerate(lhs_axes):
                axis_map[v] = rhs_axes[i]
            for batch_axis in batch_axes:
                # BUG FIX: the original applied '%' only to the second string
                # fragment (which has no placeholder), so a failing assert
                # raised TypeError instead of this message. Parenthesize the
                # whole formatted message.
                assert sum([1 if v == batch_axis else 0 \
                    for k, v in axis_map.items()]) <= 1, \
                    ("multiple broadcast on batch_axis: %s, "
                     "which is not support by dynamic shape fusion."
                     % batch_axis)
            assert wndims < 6, \
                "slice can manipulate at most 5d"

            # reduce shape to 1 for non-broadcast dimensions
            begin = tuple([0]*wndims)
            end = tuple([wshp[v] if v in axis_map.values() else 1 \
                for v in range(wndims)])
            w = nd.slice(w, begin=begin, end=end)

            # decompose k1->v, k2->v into k1->v, k2->v2
            # which make axis
            while True:
                vs, flag, paxis_map = set(), True, axis_map
                for pk, pv in paxis_map.items():
                    if pv not in vs:
                        vs.add(pv)
                        continue
                    # Duplicate target axis: materialize a copy of it so every
                    # mapping gets a distinct target, then restart the scan.
                    flag = False
                    axis_map = {k: (v+1 if v>pv or k==pk else v) \
                        for k, v in axis_map.items()}
                    w = nd.expand_dims(w, axis=pv)
                    w = nd.repeat(w, axis=pv, repeats=wshp[pv])
                    wshp = wshp[:pv] + (wshp[pv],) + wshp[pv:]
                    break
                if flag:
                    break
            wndims = len(wshp)

            # trim wndims if not equal to xndims
            v = 0
            while wndims > xndims:
                # Squeeze the lowest axis not used as a broadcast target.
                while v in axis_map.values():
                    v += 1
                w = nd.squeeze(w, axis=v)
                wndims -= 1
                axis_map = {k: (nv-1 if nv > v else nv) \
                    for k, nv in axis_map.items()}
            while wndims < xndims:
                w = nd.expand_dims(w, axis=wndims)
                wndims += 1
            # Turn axis_map into a permutation by repeated transpositions.
            axes = list(range(wndims))
            while True:
                dels = [k for k, v in axis_map.items() if k==v]
                for k in dels:
                    del axis_map[k]
                if not axis_map:
                    break
                keys = list(axis_map.keys())
                k, v = keys[0], axis_map[keys[0]]
                axes[k], axes[v] = axes[v], axes[k]
                for nk in keys:
                    nv = axis_map[nk]
                    if nv == k:
                        axis_map[nk] = v
                    elif nv == v:
                        axis_map[nk] = k
            axes = tuple(axes)
            if axes != tuple(range(wndims)):
                assert wndims < 7, \
                    "slice can manipulate at most 6d"
                w = nd.transpose(w, axes=axes)
            z = _broadcast_like(x, w)
    print(z.shape)

    # compare
    assert z.shape == y.shape
    zn, zp = get_norm(z)
    yn, yp = get_norm(y)
    rn = np.linalg.norm(zp-yp)
    print(zn, yn, rn)