def _ranking(self, inputs, predictions):
    """ Reranking generated responses. """
    src_token = inputs["src_token"]
    src_mask = inputs["src_mask"]
    src_pos = inputs["src_pos"]
    src_type = inputs["src_type"]
    src_turn = inputs["src_turn"]
    src_embed = self.embedder(src_token, src_pos, src_type, src_turn)

    batch_size, num_latent, tgt_seq_len = predictions.shape

    # shape: [batch_size, num_latent, seq_len, 1]
    preds_token = F.unsqueeze(predictions, [3])
    preds_mask = F.not_equal(preds_token, self.padding_idx, "int64")
    preds_pos = layers.range(0, tgt_seq_len, 1, dtype="float32")
    preds_pos = F.unsqueeze(preds_pos, [0, 0, 1])
    preds_pos = layers.expand(preds_pos, [batch_size, num_latent, 1, 1])
    preds_pos = layers.cast(preds_pos, "int64")
    preds_type = layers.zeros_like(preds_token)
    preds_turn = layers.zeros_like(preds_token)

    scores = []
    for i in range(num_latent):
        pred_token = preds_token[:, i]
        pred_mask = preds_mask[:, i]
        pred_pos = preds_pos[:, i]
        pred_type = preds_type[:, i]
        pred_turn = preds_turn[:, i]

        input_mask = layers.concat([src_mask, pred_mask], axis=1)
        input_mask.stop_gradient = True
        pred_embed = self.embedder(pred_token, pred_pos, pred_type, pred_turn)
        embed = layers.concat([src_embed, pred_embed], axis=1)
        embed = self.embed_layer_norm(embed)

        mask_embed = self.mask_embed
        mask_embed = layers.expand(mask_embed, [batch_size, 1, 1])
        mask_embed = self.embed_layer_norm(mask_embed)
        out = layers.concat([mask_embed, embed], axis=1)
        mask = self._create_mask(input_mask, append_head=True)

        for layer in self.layers:
            out = layer(out, mask, None)

        mask_embed = out[:, 0]
        score = self.discriminator(mask_embed)
        scores.append(score[:, 0])
    scores = layers.stack(scores, axis=1)
    return scores
def forward(self, pred, target):
    target = 1 - target[:, 0]
    batch_size, vector_size = pred.shape[0], pred.shape[1]
    pred = L.l2_normalize(pred, axis=1, epsilon=1e-10)
    square_norm = L.reduce_sum(L.square(pred), dim=1)
    dist = L.elementwise_add(-2.0 * L.matmul(pred, pred, transpose_y=True),
                             square_norm, axis=0)
    dist = L.elementwise_add(dist, square_norm, axis=1)
    dist = L.elementwise_max(dist, L.zeros_like(dist))
    dist = L.sqrt(dist)
    ap_dist = L.reshape(dist, (0, 0, 1))
    an_dist = L.reshape(dist, (0, 1, -1))

    loss = L.expand(ap_dist, (1, 1, batch_size)) - L.expand(
        an_dist, (1, batch_size, 1)) + self.magin

    indice_equal = L.diag(
        L.fill_constant((batch_size, ), dtype='float32', value=1.0))
    indice_not_equal = 1.0 - indice_equal
    broad_matrix = L.expand(L.reshape(target, (-1, 1)),
                            (1, batch_size)) + L.expand(
                                L.reshape(target, (1, -1)), (batch_size, 1))

    pp = L.cast(L.equal(broad_matrix, L.zeros_like(broad_matrix)),
                dtype='float32')
    pp = L.reshape(indice_not_equal * pp, (0, 0, 1))
    pn = L.cast(L.equal(broad_matrix, L.zeros_like(broad_matrix) + 1),
                dtype='float32')
    pn = L.reshape(indice_not_equal * pn, (1, 0, -1))
    apn = L.expand(pp, (1, 1, batch_size)) * L.expand(pn, (batch_size, 1, 1))

    loss = loss * L.cast(apn, dtype='float32')
    loss = L.elementwise_max(loss, L.zeros_like(loss))
    num_tri = L.reduce_sum(
        L.cast(L.greater_than(loss, L.zeros_like(loss)), dtype='float32'))
    loss = L.reduce_sum(loss) * self.loss_weight / (num_tri + 1e-16)
    return loss
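# Hedged sketch (not from the original repo): the loss above builds a pairwise
# squared-distance matrix with the ||a-b||^2 = ||a||^2 - 2*a.b + ||b||^2 trick.
# The snippet below reproduces just that step on dummy data in dygraph mode;
# the variable names here are illustrative only.
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as L

with fluid.dygraph.guard(fluid.CPUPlace()):
    pred = fluid.dygraph.to_variable(np.random.rand(4, 8).astype("float32"))
    pred = L.l2_normalize(pred, axis=1, epsilon=1e-10)
    square_norm = L.reduce_sum(L.square(pred), dim=1)
    dist = L.elementwise_add(-2.0 * L.matmul(pred, pred, transpose_y=True),
                             square_norm, axis=0)
    dist = L.elementwise_add(dist, square_norm, axis=1)
    dist = L.sqrt(L.elementwise_max(dist, L.zeros_like(dist)))
    # dist.numpy() is a [4, 4] distance matrix with (near) zeros on the diagonal.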
def forward(self, src, mask, query_embed, pos_embed):
    # flatten NxCxHxW to HWxNxC
    bs, c, h, w = src.shape
    src = L.reshape(src, (bs, c, -1))      # [bs, c, h * w]
    src = L.transpose(src, (0, 2, 1))      # [bs, h * w, c]
    pos_embed = L.reshape(pos_embed, (bs, pos_embed.shape[1], -1))  # [bs, c, h * w]
    pos_embed = L.transpose(pos_embed, (0, 2, 1))                   # [bs, h * w, c]
    query_embed = L.unsqueeze(query_embed, [0])        # [1, num_queries, c_q]
    query_embed = L.expand(query_embed, (bs, 1, 1))    # [bs, num_queries, c_q]
    mask = L.reshape(mask, (bs, -1))                   # [bs, h * w]
    tgt = L.zeros_like(query_embed)                    # [bs, num_queries, c_q]

    memory, encoder_attn_weights = self.encoder(
        src, src_mask=mask, pos=pos_embed)             # [bs, h * w, c]
    hs, decoder_attn_weights = self.decoder(
        tgt, memory, memory_mask=mask, pos=pos_embed, query_pos=query_embed)
    # hs: [num_inter, bs, num_queries, c_q]

    memory = L.transpose(memory, (0, 2, 1))            # [bs, c, h * w]
    memory = L.reshape(memory, (bs, c, h, w))          # [bs, c, h, w]
    return hs, memory, encoder_attn_weights, decoder_attn_weights
def compute_mask_loss(self, occ_mask, warped_image, tgt_image):
    """ Compute losses on the generated occlusion mask.

    Args:
        occ_mask (tensor): Generated occlusion masks.
        warped_image (tensor): Warped image using the flow map.
        tgt_image (tensor): Target image for the warped image.
    Returns:
        (tensor): Loss for the mask.
    """
    loss_mask = dg.to_variable(np.zeros((1, )).astype("float32"))
    if occ_mask is not None:
        dummy0 = L.zeros_like(occ_mask)
        dummy1 = L.ones_like(occ_mask)

        # Compute the confidence map based on the L1 distance between warped and GT image.
        img_diff = L.reduce_sum(L.abs(warped_image - tgt_image), 1, keep_dim=True)
        conf = L.clip(1 - img_diff, 0, 1)

        # Force mask value to be small if warped image is similar to GT, and vice versa.
        loss_mask = self.criterionMasked(occ_mask, dummy0, conf)
        loss_mask += self.criterionMasked(occ_mask, dummy1, 1 - conf)
    return loss_mask
def erniesage_v2_aggregator(gw, feature, hidden_size, act, initializer,
                            learning_rate, name):
    feature = L.unsqueeze(feature, [-1])
    msg = gw.send(ernie_send, nfeat_list=[("term_ids", feature)])
    neigh_feature = gw.recv(
        msg, lambda feat: F.layers.sequence_pool(feat, pool_type="sum"))

    term_ids = feature
    cls = L.fill_constant_batch_size_like(term_ids, [-1, 1, 1], "int64", 1)
    term_ids = L.concat([cls, term_ids], 1)
    term_ids.stop_gradient = True
    ernie = ErnieModel(term_ids, L.zeros_like(term_ids),
                       config=self.config.ernie_config)
    self_feature = ernie.get_pooled_output()

    self_feature = L.fc(
        self_feature,
        hidden_size,
        act=act,
        param_attr=F.ParamAttr(name=name + "_l", learning_rate=learning_rate),
    )
    neigh_feature = L.fc(
        neigh_feature,
        hidden_size,
        act=act,
        param_attr=F.ParamAttr(name=name + "_r", learning_rate=learning_rate),
    )
    output = L.concat([self_feature, neigh_feature], axis=1)
    output = L.l2_normalize(output, axis=1)
    return output
def batch_scatter(ref, indices, updates, in_place=False, overwrite=False):
    """Scatter `updates` into `ref` according to the corresponding index in
    `indices` for each batch. Currently it only supports 2-D tensors.

    Args:
        ref (Variable): with shape [batch_size, ...]
        indices (Variable): with shape [batch_size, 1]
        updates (Variable): with shape [batch_size]
        in_place (bool): if True, the scatter result is assigned to ref;
            otherwise a new Tensor is returned. Default is False.
        overwrite (bool): if True, scatter overwrites the corresponding
            elements. Default is False.

    Returns:
        TODO

    Raises:
        NULL

    Examples:
        ref
            [[1, 1, 1],
             [1, 1, 1]]
        indices
            [[2], [1]]
        updates
            [2, 3]

        return
            [[1, 1, 2],
             [1, 3, 1]]
    """
    ref_dtype = ref.dtype
    if ref_dtype not in PaddleVarType.floats:
        ref_in = layers.cast(ref, dtype='float32')
    else:
        ref_in = ref

    if updates.dtype != ref_in.dtype:
        updates = layers.cast(updates, dtype=ref_in.dtype)

    batch_size = layers.cast(layers.shape(ref_in)[0], dtype=indices.dtype)
    zero = layers.fill_constant(shape=[1], dtype=indices.dtype, value=0)
    one = layers.fill_constant(shape=[1], dtype=indices.dtype, value=1)
    batch_indices = layers.unsqueeze(
        layers.range(zero, batch_size, one, dtype=indices.dtype), [1])
    coord = layers.concat([batch_indices, indices], axis=1)
    if overwrite:
        mask = layers.gather_nd(ref_in, coord)
        mask = layers.elementwise_sub(layers.zeros_like(mask), mask)
        ref_in = layers.scatter_nd_add(ref_in, coord, mask)
    output = layers.scatter_nd_add(ref_in, coord, updates)

    if ref_dtype not in PaddleVarType.floats:
        output = layers.cast(output, dtype=ref_dtype)
    if in_place:
        layers.assign(output, ref)
        return ref
    else:
        return output
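# Hedged usage sketch mirroring the docstring example of batch_scatter above.
# It assumes dygraph mode and that batch_scatter and the helpers it relies on
# (layers, PaddleVarType) are importable and dygraph-compatible; it is not taken
# from the original repo.
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard(fluid.CPUPlace()):
    ref = fluid.dygraph.to_variable(
        np.array([[1, 1, 1], [1, 1, 1]], dtype="float32"))
    indices = fluid.dygraph.to_variable(np.array([[2], [1]], dtype="int64"))
    updates = fluid.dygraph.to_variable(np.array([2, 3], dtype="float32"))
    out = batch_scatter(ref, indices, updates, overwrite=True)
    # expected, per the docstring: [[1, 1, 2], [1, 3, 1]]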
def _debug_summary(self, input_mask):
    # histogram
    seqlen_before_pad = L.cast(L.reduce_sum(input_mask, dim=1), dtype='float32')
    seqlen_after_pad = L.reduce_sum(
        L.cast(L.zeros_like(input_mask), dtype='float32') + 1.0, dim=1)
    pad_num = seqlen_after_pad - seqlen_before_pad
    pad_rate = pad_num / seqlen_after_pad
def compute_mask_losses(self, occ_mask, fake_image, warped_image, tgt_label,
                        tgt_image, fg_mask, ref_fg_mask, body_mask_diff):
    """ Compute losses on the generated occlusion masks.

    Args:
        occ_mask (tensor or list of tensors): Generated occlusion masks.
        fake_image (tensor): Generated image.
        warped_image (tensor or list of tensors): Warped images using the flow maps.
        tgt_label (tensor): Target label map.
        tgt_image (tensor): Target image for the warped image.
        fg_mask (tensor): Foreground mask for the target image.
        ref_fg_mask (tensor): Foreground mask for the reference image.
        body_mask_diff (tensor): Difference between warped body part map and
            target body part map. Used for pose dataset only.
    """
    loss_mask = dg.to_variable(np.zeros((1, )).astype("float32"))
    if isinstance(occ_mask, list):
        # Compute occlusion mask losses for both warping reference -> target
        # and previous -> target.
        for i in range(len(occ_mask)):
            loss_mask += self.compute_mask_loss(occ_mask[i], warped_image[i],
                                                tgt_image)
    else:
        # Compute loss for warping either reference or previous images.
        loss_mask += self.compute_mask_loss(occ_mask, warped_image, tgt_image)

    if self.warp_ref:
        ref_occ_mask = occ_mask[0]
        dummy0 = L.zeros_like(ref_occ_mask)
        dummy1 = L.ones_like(ref_occ_mask)

        if self.for_pose_dataset:
            # Enforce output to use more warped reference image for face region.
            face_mask = L.unsqueeze(get_face_mask(tgt_label[:, 2]), [1])
            face_mask = L.pool2d(face_mask, pool_size=15, pool_type='avg',
                                 pool_stride=1, pool_padding=7)
            loss_mask += self.criterionMasked(ref_occ_mask, dummy0, face_mask)
            loss_mask += self.criterionMasked(fake_image, warped_image[0], face_mask)

            # Enforce output to use more hallucinated image for discrepancy
            # regions of body part masks between warped reference and target image.
            loss_mask += self.criterionMasked(ref_occ_mask, dummy1, body_mask_diff)

        if self.has_fg:
            # Enforce output to use more hallucinated image for discrepancy regions
            # of foreground masks between reference and target image.
            fg_mask_diff = ((ref_fg_mask - fg_mask) > 0).astype("float32")
            loss_mask += self.criterionMasked(ref_occ_mask, dummy1, fg_mask_diff)
    return loss_mask
def forward(self, mu, logvar=None):
    """ Compute loss

    Args:
        mu (tensor): mean
        logvar (tensor): logarithm of variance
    """
    if logvar is None:
        logvar = L.zeros_like(mu)
    return -0.5 * L.reduce_sum(1 + logvar - L.pow(mu, 2) - L.exp(logvar))
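# Hedged numeric check (not from the original repo): with mu = 0 and the default
# logvar = 0, the formula above reduces to KL(N(0, I) || N(0, I)) = 0. The lines
# below recompute the same expression on concrete tensors in dygraph mode.
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as L

with fluid.dygraph.guard(fluid.CPUPlace()):
    mu = fluid.dygraph.to_variable(np.zeros((2, 4), dtype="float32"))
    logvar = L.zeros_like(mu)
    kl = -0.5 * L.reduce_sum(1 + logvar - L.pow(mu, 2) - L.exp(logvar))
    # kl.numpy() is (numerically) zero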
def build_embedding(self, graph_wrappers, term_ids):
    term_ids = L.unsqueeze(term_ids, [-1])
    ernie_config = self.config.ernie_config
    ernie = ErnieModel(src_ids=term_ids,
                       sentence_ids=L.zeros_like(term_ids),
                       task_ids=None,
                       config=ernie_config,
                       use_fp16=False,
                       name="student_")
    feature = ernie.get_pooled_output()
    return feature
def decrement(self):
    new_scale = self.scale / self.factor
    one = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
    less_than_one = layers.less_than(new_scale, one)
    with layers.Switch() as switch:
        with switch.case(less_than_one):
            layers.assign(one, self.scale)
        with switch.default():
            layers.assign(new_scale, self.scale)
    layers.assign(layers.zeros_like(self.good_steps), self.good_steps)
def forward(self, features):
    src_ids, sent_ids = features
    dtype = 'float16' if self.hparam['fp16'] else 'float32'
    zero = L.fill_constant([1], dtype='int64', value=0)
    input_mask = L.cast(L.logical_not(L.equal(src_ids, zero)), dtype)  # assume pad id == 0
    #input_mask = L.unsqueeze(input_mask, axes=[2])
    d_shape = L.shape(src_ids)
    seqlen = d_shape[1]
    batch_size = d_shape[0]
    pos_ids = L.unsqueeze(L.range(0, seqlen, 1, dtype='int32'), axes=[0])
    pos_ids = L.expand(pos_ids, [batch_size, 1])
    pos_ids = L.unsqueeze(pos_ids, axes=[2])
    pos_ids = L.cast(pos_ids, 'int64')
    pos_ids.stop_gradient = True
    input_mask.stop_gradient = True
    task_ids = L.zeros_like(src_ids) + self.hparam.task_id  # task ids are unused at the moment
    task_ids.stop_gradient = True

    bert = ErnieModel(
        src_ids=src_ids,
        position_ids=pos_ids,
        sentence_ids=sent_ids,
        task_ids=task_ids,
        input_mask=input_mask,
        config=self.hparam,
        use_fp16=self.hparam['fp16'])

    cls_feats = bert.get_pooled_output()
    cls_feats = L.dropout(
        x=cls_feats,
        dropout_prob=0.1,
        dropout_implementation="upscale_in_train")
    logits = L.fc(
        input=cls_feats,
        size=self.hparam['num_label'],
        param_attr=F.ParamAttr(
            name="cls_out_w",
            initializer=F.initializer.TruncatedNormal(scale=0.02)),
        bias_attr=F.ParamAttr(
            name="cls_out_b", initializer=F.initializer.Constant(0.)))

    propeller.summary.histogram('pred', logits)

    if self.mode is propeller.RunMode.PREDICT:
        probs = L.softmax(logits)
        return probs
    else:
        return logits
def sag_pool(gw, feature, ratio, graph_id, dataset, name, activation=L.tanh):
    """Implementation of self-attention graph pooling (SAGPool)

    This is an implementation of the paper SELF-ATTENTION GRAPH POOLING
    (https://arxiv.org/pdf/1904.08082.pdf)

    Args:
        gw: Graph wrapper object.
        feature: A tensor with shape (num_nodes, feature_size).
        ratio: The pooling ratio of nodes we want to select.
        graph_id: The graphs that the nodes belong to.
        dataset: To differentiate FRANKENSTEIN dataset and other datasets.
        name: The name of SAGPool layer.
        activation: The activation function.

    Return:
        new_feature: A tensor with shape (num_nodes, feature_size), where the
            features of unselected nodes are masked by zero.
        ratio_length: The selected node numbers of each graph.
    """
    if dataset == "FRANKENSTEIN":
        gcn_ = gcn
    else:
        gcn_ = norm_gcn

    score = gcn_(gw=gw,
                 feature=feature,
                 hidden_size=1,
                 activation=None,
                 norm=gw.node_feat["norm"],
                 name=name)
    score = L.squeeze(score, axes=[])
    perm, ratio_length = topk_pool(gw, score, graph_id, ratio)

    mask = L.zeros_like(score)
    mask = L.cast(mask, dtype="float32")
    updates = L.ones_like(perm)
    updates = L.cast(updates, dtype="float32")
    mask = L.scatter(mask, perm, updates)
    new_feature = L.elementwise_mul(feature, mask, axis=0)
    temp_score = activation(score)
    new_feature = L.elementwise_mul(new_feature, temp_score, axis=0)
    return new_feature, ratio_length
def pop(cls, stack_data, mask=True, in_place=True):
    """Pop data from stack_data.

    Args:
        stack_data (StackData): (data, pos) with shape
            ([batch_size, stack_len], [batch_size, 1])
        mask (bool): whether to mask the value popped from an already-empty
            stack. Default is True.
        in_place (bool): Default is True.

    Returns:
        (Variable1, Variable2)
        Variable1: the popped values. dtype=stack_data.data.dtype, shape=[-1]
        Variable2: whether the value at each position is valid. For stacks
            that were already empty on input, this is False.
            dtype=bool, shape=[-1]

    Raises:
        NULL
    """
    data = stack_data.data
    pos = stack_data.pos

    # only non-empty stacks can be popped (i.e. are valid)
    valid_pos = layers.logical_not(cls.empty(stack_data))
    new_pos_delta = layers.cast(valid_pos, dtype=pos.dtype)
    new_pos = layers.elementwise_sub(pos, new_pos_delta)

    # shape = [batch_size]
    output = nn_utils.batch_gather(data, new_pos)

    # mask the values returned for empty stacks
    if mask:
        # shape = [batch_size, 1]
        mask_tag = layers.cast(
            new_pos_delta,
            dtype=data.dtype) if data.dtype != pos.dtype else new_pos_delta
        mask_tag = layers.squeeze(mask_tag, [1])
        output = layers.elementwise_mul(output, mask_tag)

    # after popping, reset the original position to 0
    updates = layers.zeros_like(output)
    new_data = nn_utils.batch_scatter(data, new_pos, updates,
                                      overwrite=True, in_place=in_place)

    if in_place:
        layers.assign(new_pos, pos)
        return output, valid_pos, stack_data
    else:
        return output, valid_pos, StackData(new_data, new_pos)
def increment(self):
    enough_steps = layers.less_than(self.increment_every, self.good_steps + 1)
    with layers.Switch() as switch:
        with switch.case(enough_steps):
            new_scale = self.scale * self.factor
            scale_valid = layers.isfinite(new_scale)
            with layers.Switch() as switch2:
                with switch2.case(scale_valid):
                    layers.assign(new_scale, self.scale)
                    layers.assign(layers.zeros_like(self.good_steps),
                                  self.good_steps)
                with switch2.default():
                    layers.increment(self.good_steps)
        with switch.default():
            layers.increment(self.good_steps)
def empty(cls, stack_data, dtype='bool'):
    """Return True if the stack is empty (pos == 0).

    Args:
        stack_data (TYPE): NULL
        dtype (str): result dtype. Default is bool.

    Returns:
        Variable
            shape=[-1], dtype=params<dtype>

    Raises:
        NULL
    """
    zeros = layers.zeros_like(stack_data.pos)
    output = layers.equal(stack_data.pos, zeros)
    if dtype != 'bool':
        output = layers.cast(output, dtype=dtype)
    return output
def ernie_send(src_feat, dst_feat, edge_feat):
    """doc"""
    cls = L.fill_constant_batch_size_like(src_feat["term_ids"], [-1, 1, 1],
                                          "int64", 1)
    src_ids = L.concat([cls, src_feat["term_ids"]], 1)
    dst_ids = dst_feat["term_ids"]

    sent_ids = L.concat([L.zeros_like(src_ids), L.ones_like(dst_ids)], 1)
    term_ids = L.concat([src_ids, dst_ids], 1)

    term_ids.stop_gradient = True
    sent_ids.stop_gradient = True
    ernie = ErnieModel(term_ids, sent_ids, config=self.config.ernie_config)
    feature = ernie.get_pooled_output()
    return feature
def _push_to_stack(gmr_desc, gmr_pos, gmr_lens, gmr_stack_info):
    """Push grammar ids in gmr_desc, from gmr_pos to gmr_lens, onto gmr_stack,
    and update step_gmr_pos.

    Args:
        gmr_desc (TYPE): NULL
        gmr_pos (TYPE): NULL
        gmr_lens (TYPE): NULL
        gmr_stack_info (tuple): [in/out] (gmr_stack, gmr_stack_pos)

    Returns:
        tuple (gmr_stack, gmr_stack_pos)

    Raises:
        NULL
    """
    gmr_stack, gmr_stack_pos = gmr_stack_info
    mv_step = layers.cast(
        layers.greater_than(gmr_lens, layers.zeros_like(gmr_lens)),
        dtype=gmr_lens.dtype)
    gmr_mv_pos = layers.elementwise_sub(gmr_lens, mv_step)

    cond = layers.reduce_any(layers.greater_than(gmr_mv_pos, gmr_pos))
    while_op = layers.While(cond)
    with while_op.block():
        gmr_ids = nn_utils.batch_gather(gmr_desc, gmr_mv_pos)
        gmr_stack_tmp, gmr_stack_pos_tmp = data_structure.Stack.push(
            gmr_stack_info, gmr_ids, in_place=False)

        mv_cond = layers.greater_than(gmr_mv_pos, gmr_pos)
        gmr_mv_pos_tmp = fluider.elementwise_sub(gmr_mv_pos, mv_cond, force=True)
        new_gmr_stack, new_gmr_stack_pos = nn_utils.ifelse(
            mv_cond, [gmr_stack_tmp, gmr_stack_pos_tmp],
            [gmr_stack, gmr_stack_pos])

        layers.utils.map_structure(layers.assign,
                                   [new_gmr_stack, new_gmr_stack_pos],
                                   [gmr_stack, gmr_stack_pos])
        layers.assign(gmr_mv_pos_tmp, gmr_mv_pos)
        layers.assign(
            layers.reduce_any(layers.greater_than(gmr_mv_pos, gmr_pos)), cond)
    return gmr_stack, gmr_stack_pos
def forward(self, features):
    src_ids, sent_ids, input_seqlen = features
    zero = L.fill_constant([1], dtype='int64', value=0)
    input_mask = L.cast(L.equal(src_ids, zero), 'float32')  # assume pad id == 0
    #input_mask = L.unsqueeze(input_mask, axes=[2])
    d_shape = L.shape(src_ids)
    seqlen = d_shape[1]
    batch_size = d_shape[0]
    pos_ids = L.unsqueeze(L.range(0, seqlen, 1, dtype='int32'), axes=[0])
    pos_ids = L.expand(pos_ids, [batch_size, 1])
    pos_ids = L.unsqueeze(pos_ids, axes=[2])
    pos_ids = L.cast(pos_ids, 'int64')
    pos_ids.stop_gradient = True
    input_mask.stop_gradient = True
    task_ids = L.zeros_like(src_ids) + self.hparam.task_id  # task ids are unused at the moment
    task_ids.stop_gradient = True

    model = ErnieModel(src_ids=src_ids,
                       position_ids=pos_ids,
                       sentence_ids=sent_ids,
                       task_ids=task_ids,
                       input_mask=input_mask,
                       config=self.hparam,
                       use_fp16=self.hparam['use_fp16'])

    enc_out = model.get_sequence_output()
    logits = L.fc(
        input=enc_out,
        size=self.num_label,
        num_flatten_dims=2,
        param_attr=F.ParamAttr(
            name="cls_seq_label_out_w",
            initializer=F.initializer.TruncatedNormal(scale=0.02)),
        bias_attr=F.ParamAttr(name="cls_seq_label_out_b",
                              initializer=F.initializer.Constant(0.)))

    propeller.summary.histogram('pred', logits)
    return logits, input_seqlen
def unpool(value):
    """Unpooling operation.

    N-dimensional version of the unpooling operation from
    https://www.robots.ox.ac.uk/~vgg/rg/papers/Dosovitskiy_Learning_to_Generate_2015_CVPR_paper.pdf
    Taken from: https://github.com/tensorflow/tensorflow/issues/2169

    Args:
        value: a Tensor of shape [b, d0, d1, ..., dn, ch]
        name: name of the op
    Returns:
        A Tensor of shape [b, 2*d0, 2*d1, ..., 2*dn, ch]
    """
    value = layers.transpose(value, [0, 2, 3, 1])
    sh = value.shape
    dim = len(sh[1:-1])
    out = (layers.reshape(value, [-1] + sh[-dim:]))
    for i in range(dim, 0, -1):
        out = layers.concat([out, layers.zeros_like(out)], i)
    out_size = [-1] + [s * 2 for s in sh[1:-1]] + [sh[-1]]
    out = layers.reshape(out, out_size)
    out = layers.transpose(out, [0, 3, 1, 2])
    return out
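# Hedged shape check (not from the original repo): unpool should double both
# spatial dims of an NCHW tensor, e.g. [2, 3, 4, 4] -> [2, 3, 8, 8]. Assumes
# dygraph mode and that unpool (and its `layers` import) is defined as above.
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard(fluid.CPUPlace()):
    x = fluid.dygraph.to_variable(np.ones((2, 3, 4, 4), dtype="float32"))
    y = unpool(x)
    # y.shape -> [2, 3, 8, 8]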
def backward(self, loss, **kwargs):
    state = mixed_precision_global_state()
    callbacks = 'callbacks' in kwargs and kwargs['callbacks'] or None
    if callbacks is None:
        from paddle.fluid.clip import error_clip_callback
        callbacks = [error_clip_callback]  # XXX what if gradient is zero?
    if state is not None:
        kwargs['callbacks'] = [scale_gradient] + callbacks
    else:
        kwargs['callbacks'] = callbacks
    param_grads = self._backward(loss, **kwargs)

    if state is not None:
        grad_valid = update_loss_scale(v for k, v in param_grads)
        if state.dynamic_scaling:
            with layers.Switch() as switch:
                with switch.case(grad_valid):
                    pass
                with switch.default():
                    for _, g in param_grads:
                        layers.assign(layers.zeros_like(g), g)

    return param_grads
def ernie_send(src_feat, dst_feat, edge_feat):
    """doc"""

    def build_position_ids(term_ids):
        input_mask = L.cast(term_ids > 0, "int64")
        position_ids = L.cumsum(input_mask, axis=1) - 1
        return position_ids

    # input_ids
    cls = L.fill_constant_batch_size_like(src_feat["term_ids"], [-1, 1],
                                          "int64", self.config.cls_id)
    src_ids = L.concat([cls, src_feat["term_ids"]], 1)
    dst_ids = dst_feat["term_ids"]

    # sent_ids
    sent_ids = L.concat([L.zeros_like(src_ids), L.ones_like(dst_ids)], 1)
    term_ids = L.concat([src_ids, dst_ids], 1)

    # position_ids
    position_ids = build_position_ids(term_ids)

    ernie_model = ErnieModel(self.config.ernie_config, "")
    feature, _ = ernie_model(term_ids, sent_ids, position_ids)
    return feature
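# Hedged sketch (not from the original repo) of the position-id trick used by
# build_position_ids above: cumsum over the padding mask yields 0-based
# positions for real tokens, while pad slots just repeat the last real position.
# The dummy ids below are illustrative only.
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as L

with fluid.dygraph.guard(fluid.CPUPlace()):
    term_ids = fluid.dygraph.to_variable(
        np.array([[5, 6, 7, 0, 0], [8, 9, 0, 0, 0]], dtype="int64"))
    input_mask = L.cast(term_ids > 0, "int64")
    position_ids = L.cumsum(input_mask, axis=1) - 1
    # position_ids.numpy() ->
    # [[0, 1, 2, 2, 2],
    #  [0, 1, 1, 1, 1]]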
def train(self):
    place = fluid.CUDAPlace(0) if self.use_gpu else fluid.CPUPlace()
    with fluid.dygraph.guard(place):
        self.genA2B.train()
        self.genB2A.train()
        self.disGA.train()
        self.disGB.train()
        self.disLA.train()
        self.disLB.train()

        if self.resume:
            files_list = os.listdir(self.model_path)
            if len(files_list) > 0:
                files = []
                print("exist files")
                for i in files_list:
                    file_ = os.path.splitext(i)[1]
                    files.append(file_)
                if ".pdparams" in files or ".pdopt" in files:
                    print("exist model")
                    genA2B_para = fluid.load_dygraph(self.model_path + 'g_A2B')
                    genB2A_para = fluid.load_dygraph(self.model_path + 'g_B2A')
                    disGA_para = fluid.load_dygraph(self.model_path + 'd_GA')
                    disGB_para = fluid.load_dygraph(self.model_path + 'd_GB')
                    disLA_para = fluid.load_dygraph(self.model_path + 'd_LA')
                    disLB_para = fluid.load_dygraph(self.model_path + 'd_LB')
                    G_opt = fluid.load_dygraph(self.model_path + 'G_op')
                    D_opt = fluid.load_dygraph(self.model_path + 'D_op')

                    self.genA2B.load_dict(genA2B_para)
                    self.genB2A.load_dict(genB2A_para)
                    self.disGA.load_dict(disGA_para)
                    self.disGB.load_dict(disGB_para)
                    self.disLA.load_dict(disLA_para)
                    self.disLB.load_dict(disLB_para)
                    self.G_optim.set_dict(G_opt)
                    self.D_optim.set_dict(D_opt)
                    print(" [*] Load SUCCESS")
                else:
                    print(" No Model!")
            else:
                print("No Files")

        # training loop
        print('training start !')
        start_iter = 1
        for step in range(start_iter, self.iteration + 1):
            trainA_iter = iter(self.trainA_loader())
            real_A = next(trainA_iter)
            real_A = paddle.fluid.dygraph.to_variable(np.array(real_A))
            real_A = real_A / 255.0

            trainB_iter = iter(self.trainB_loader())
            real_B = next(trainB_iter)
            real_B = paddle.fluid.dygraph.to_variable(np.array(real_B))
            real_B = real_B / 255.0

            # Update D
            self.D_optim.clear_gradients()

            fake_A2B, _, _ = self.genA2B(real_A)
            fake_B2A, _, _ = self.genB2A(real_B)

            real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
            real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
            real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
            real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            D_ad_loss_GA = self.MSE_loss(real_GA_logit, fluid.dygraph.to_variable(ones_like(real_GA_logit))) \
                + self.MSE_loss(fake_GA_logit, fluid.dygraph.to_variable(zeros_like(fake_GA_logit)))
            D_ad_cam_loss_GA = self.MSE_loss(real_GA_cam_logit, fluid.dygraph.to_variable(ones_like(real_GA_cam_logit))) \
                + self.MSE_loss(fake_GA_cam_logit, fluid.dygraph.to_variable(zeros_like(fake_GA_cam_logit)))
            D_ad_loss_LA = self.MSE_loss(real_LA_logit, fluid.dygraph.to_variable(ones_like(real_LA_logit))) \
                + self.MSE_loss(fake_LA_logit, fluid.dygraph.to_variable(zeros_like(fake_LA_logit)))
            D_ad_cam_loss_LA = self.MSE_loss(real_LA_cam_logit, fluid.dygraph.to_variable(ones_like(real_LA_cam_logit))) \
                + self.MSE_loss(fake_LA_cam_logit, fluid.dygraph.to_variable(zeros_like(fake_LA_cam_logit)))
            D_ad_loss_GB = self.MSE_loss(real_GB_logit, fluid.dygraph.to_variable(ones_like(real_GB_logit))) \
                + self.MSE_loss(fake_GB_logit, fluid.dygraph.to_variable(zeros_like(fake_GB_logit)))
            D_ad_cam_loss_GB = self.MSE_loss(real_GB_cam_logit, fluid.dygraph.to_variable(ones_like(real_GB_cam_logit))) \
                + self.MSE_loss(fake_GB_cam_logit, fluid.dygraph.to_variable(zeros_like(fake_GB_cam_logit)))
            D_ad_loss_LB = self.MSE_loss(real_LB_logit, fluid.dygraph.to_variable(ones_like(real_LB_logit))) \
                + self.MSE_loss(fake_LB_logit, fluid.dygraph.to_variable(zeros_like(fake_LB_logit)))
            D_ad_cam_loss_LB = self.MSE_loss(real_LB_cam_logit, fluid.dygraph.to_variable(ones_like(real_LB_cam_logit))) \
                + self.MSE_loss(fake_LB_cam_logit, fluid.dygraph.to_variable(zeros_like(fake_LB_cam_logit)))

            D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                          D_ad_loss_LA + D_ad_cam_loss_LA)
            D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                          D_ad_loss_LB + D_ad_cam_loss_LB)

            Discriminator_loss = D_loss_A + D_loss_B
            Discriminator_loss.backward()
            self.D_optim.minimize(Discriminator_loss)

            # Update G
            self.G_optim.clear_gradients()

            fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
            fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

            fake_A2B2A, _, _ = self.genB2A(fake_A2B)
            fake_B2A2B, _, _ = self.genA2B(fake_B2A)

            fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
            fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

            G_ad_loss_GA = self.MSE_loss(fake_GA_logit, fluid.dygraph.to_variable(ones_like(fake_GA_logit)))
            G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit, fluid.dygraph.to_variable(ones_like(fake_GA_cam_logit)))
            G_ad_loss_LA = self.MSE_loss(fake_LA_logit, fluid.dygraph.to_variable(ones_like(fake_LA_logit)))
            G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit, fluid.dygraph.to_variable(ones_like(fake_LA_cam_logit)))
            G_ad_loss_GB = self.MSE_loss(fake_GB_logit, fluid.dygraph.to_variable(ones_like(fake_GB_logit)))
            G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit, fluid.dygraph.to_variable(ones_like(fake_GB_cam_logit)))
            G_ad_loss_LB = self.MSE_loss(fake_LB_logit, fluid.dygraph.to_variable(ones_like(fake_LB_logit)))
            G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit, fluid.dygraph.to_variable(ones_like(fake_LB_cam_logit)))

            G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
            G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

            G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
            G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

            G_cam_loss_A = self.BCE_loss(fake_B2A_cam_logit, fluid.dygraph.to_variable(ones_like(fake_B2A_cam_logit))) \
                + self.BCE_loss(fake_A2A_cam_logit, fluid.dygraph.to_variable(zeros_like(fake_A2A_cam_logit)))
            G_cam_loss_B = self.BCE_loss(fake_A2B_cam_logit, fluid.dygraph.to_variable(ones_like(fake_A2B_cam_logit))) \
                + self.BCE_loss(fake_B2B_cam_logit, fluid.dygraph.to_variable(zeros_like(fake_B2B_cam_logit)))

            G_loss_A = self.adv_weight * (G_ad_loss_GA + G_ad_cam_loss_GA +
                                          G_ad_loss_LA + G_ad_cam_loss_LA) \
                + self.cycle_weight * G_recon_loss_A \
                + self.identity_weight * G_identity_loss_A \
                + self.cam_weight * G_cam_loss_A
            G_loss_B = self.adv_weight * (G_ad_loss_GB + G_ad_cam_loss_GB +
                                          G_ad_loss_LB + G_ad_cam_loss_LB) \
                + self.cycle_weight * G_recon_loss_B \
                + self.identity_weight * G_identity_loss_B \
                + self.cam_weight * G_cam_loss_B

            Generator_loss = G_loss_A + G_loss_B
            Generator_loss.backward()
            self.G_optim.minimize(Generator_loss)

            # clip parameter of AdaILN and ILN, applied after optimizer step
            clip_rho(self.genA2B, vmin=0, vmax=1)
            clip_rho(self.genB2A, vmin=0, vmax=1)

            if step % 50 == 0:
                print("[%5d/%5d] d_loss: %.8f, g_loss: %.8f" %
                      (step, self.iteration, Discriminator_loss, Generator_loss))

            if step % self.print_freq == 0:
                print("print img!")
                train_sample_num = 5
                test_sample_num = 5
                A2B = np.zeros((self.img_size * 7, 0, 3))
                B2A = np.zeros((self.img_size * 7, 0, 3))

                self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(), \
                    self.disGB.eval(), self.disLA.eval(), self.disLB.eval()
                for _ in range(train_sample_num):
                    trainA_iter = iter(self.trainA_loader())
                    real_A = next(trainA_iter)
                    real_A = paddle.fluid.dygraph.to_variable(np.array(real_A))
                    real_A = real_A / 255.0

                    trainB_iter = iter(self.trainB_loader())
                    real_B = next(trainB_iter)
                    real_B = paddle.fluid.dygraph.to_variable(np.array(real_B))
                    real_B = real_B / 255.0

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    A2B = np.concatenate(
                        (A2B,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                              cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                              cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                              cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)
                    B2A = np.concatenate(
                        (B2A,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                              cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                              cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                              cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)

                for _ in range(test_sample_num):
                    testA_iter = iter(self.testA_loader())
                    real_A = next(testA_iter)
                    real_A = paddle.fluid.dygraph.to_variable(np.array(real_A))
                    real_A = real_A / 255.0

                    testB_iter = iter(self.testB_loader())
                    real_B = next(testB_iter)
                    real_B = paddle.fluid.dygraph.to_variable(np.array(real_B))
                    real_B = real_B / 255.0

                    fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                    fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                    fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                    fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                    fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                    fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                    A2B = np.concatenate(
                        (A2B,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                              cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                              cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                              cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)
                    B2A = np.concatenate(
                        (B2A,
                         np.concatenate(
                             (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                              cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                              cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                              cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
                              RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)

                cv2.imwrite(os.path.join(self.result_dir, 'A2B_%07d.png' % step),
                            A2B * 255.0)
                cv2.imwrite(os.path.join(self.result_dir, 'B2A_%07d.png' % step),
                            B2A * 255.0)

            if step % self.save_freq == 0:
                fluid.save_dygraph(self.genA2B.state_dict(), self.model_path + 'g_A2B')
                fluid.save_dygraph(self.genB2A.state_dict(), self.model_path + 'g_B2A')
                fluid.save_dygraph(self.disGA.state_dict(), self.model_path + 'd_GA')
                fluid.save_dygraph(self.disGB.state_dict(), self.model_path + 'd_GB')
                fluid.save_dygraph(self.disLA.state_dict(), self.model_path + 'd_LA')
                fluid.save_dygraph(self.disLB.state_dict(), self.model_path + 'd_LB')
                fluid.save_dygraph(self.G_optim.state_dict(), self.model_path + 'g_A2B')
                fluid.save_dygraph(self.G_optim.state_dict(), self.model_path + 'g_B2A')
                fluid.save_dygraph(self.D_optim.state_dict(), self.model_path + 'd_GA')
                fluid.save_dygraph(self.D_optim.state_dict(), self.model_path + 'd_GB')
                fluid.save_dygraph(self.D_optim.state_dict(), self.model_path + 'd_LA')
                fluid.save_dygraph(self.D_optim.state_dict(), self.model_path + 'd_LB')
def forward(self,
            src_ids,
            sent_ids=None,
            pos_ids=None,
            input_mask=None,
            attn_bias=None,
            past_cache=None,
            use_causal_mask=False):
    """
    Args:
        src_ids (`Variable` of shape `[batch_size, seq_len]`):
            Indices of input sequence tokens in the vocabulary.
        sent_ids (optional, `Variable` of shape `[batch_size, seq_len]`):
            aka token_type_ids, segment token indices to indicate first and second
            portions of the inputs. If None, all tokens are assumed to come from `segment_a`.
        pos_ids (optional, `Variable` of shape `[batch_size, seq_len]`):
            Indices of positions of each input sequence token in the position embeddings.
        input_mask (optional, `Variable` of shape `[batch_size, seq_len]`):
            Mask to avoid performing attention on the padding token indices of the encoder input.
        attn_bias (optional, `Variable` of shape `[batch_size, seq_len, seq_len]` or False):
            3D version of `input_mask`; if set, overrides `input_mask`;
            if set to not False, no attention mask will be applied.
        past_cache (optional, tuple of two lists: cached key and cached value,
            each a list of `Variable`s of shape `[batch_size, seq_len, hidden_size]`):
            cached key/value tensors that will be concatenated to the generated key/value
            when performing self attention. If set, `attn_bias` should not be None.

    Returns:
        pooled (`Variable` of shape `[batch_size, hidden_size]`):
            output logits of the pooler classifier
        encoded (`Variable` of shape `[batch_size, seq_len, hidden_size]`):
            output logits of the transformer stack
    """
    assert len(src_ids.shape) == 2, \
        'expect src_ids.shape = [batch, sequence], got %s' % (repr(src_ids.shape))
    assert attn_bias is not None if past_cache else True, \
        'if `past_cache` is specified, attn_bias should not be None'
    d_batch = L.shape(src_ids)[0]
    d_seqlen = L.shape(src_ids)[1]
    if pos_ids is None:
        pos_ids = L.reshape(L.range(0, d_seqlen, 1, dtype='int32'), [1, -1])
        pos_ids = L.cast(pos_ids, 'int64')
    if attn_bias is None:
        if input_mask is None:
            input_mask = L.cast(src_ids != 0, 'float32')
        assert len(input_mask.shape) == 2
        input_mask = L.unsqueeze(input_mask, axes=[-1])
        attn_bias = L.matmul(input_mask, input_mask, transpose_y=True)
        if use_causal_mask:
            sequence = L.reshape(
                L.range(0, d_seqlen, 1, dtype='float32') + 1., [1, 1, -1, 1])
            causal_mask = L.cast(
                (L.matmul(sequence, 1. / sequence, transpose_y=True) >= 1.),
                'float32')
            attn_bias *= causal_mask
    else:
        assert len(attn_bias.shape) == 3, \
            'expect attn_bias to be rank 3, got %r' % attn_bias.shape
    attn_bias = (1. - attn_bias) * -10000.0
    attn_bias = L.unsqueeze(attn_bias, [1])
    attn_bias = L.expand(attn_bias, [1, self.n_head, 1, 1])  # avoid broadcast =_=
    attn_bias.stop_gradient = True

    if sent_ids is None:
        sent_ids = L.zeros_like(src_ids)

    src_embedded = self.word_emb(src_ids)
    pos_embedded = self.pos_emb(pos_ids)
    sent_embedded = self.sent_emb(sent_ids)
    embedded = src_embedded + pos_embedded + sent_embedded

    embedded = self.dropout(self.ln(embedded))

    encoded, hidden_list, cache_list = self.encoder_stack(
        embedded, attn_bias, past_cache=past_cache)
    if self.pooler is not None:
        pooled = self.pooler(encoded[:, 0, :])
    else:
        pooled = None

    additional_info = {
        'hiddens': hidden_list,
        'caches': cache_list,
    }

    if self.return_additional_info:
        return pooled, encoded, additional_info
    else:
        return pooled, encoded
def zero_grad():
    for _, g in param_grads:
        layers.assign(layers.zeros_like(g), g)
def beam_search(enc_output, enc_bias, source_length):
    """ beam_search """
    max_len = layers.fill_constant(shape=[1], dtype='int64', value=max_out_len)
    step_idx = layers.fill_constant(shape=[1], dtype='int64', value=0)
    cond = layers.less_than(x=step_idx, y=max_len)
    while_op = layers.While(cond)

    caches_batch_size = batch_size * beam_size
    init_score = np.zeros([1, beam_size]).astype('float32')
    init_score[:, 1:] = -INF
    initial_log_probs = layers.assign(init_score)

    alive_log_probs = layers.expand(initial_log_probs, [batch_size, 1])
    # alive seq [batch_size, beam_size, 1]
    initial_ids = layers.zeros([batch_size, 1, 1], 'float32')
    alive_seq = layers.expand(initial_ids, [1, beam_size, 1])
    alive_seq = layers.cast(alive_seq, 'int64')

    enc_output = layers.unsqueeze(enc_output, axes=[1])
    enc_output = layers.expand(enc_output, [1, beam_size, 1, 1])
    enc_output = layers.reshape(enc_output, [caches_batch_size, -1, d_model])

    tgt_src_attn_bias = layers.unsqueeze(enc_bias, axes=[1])
    tgt_src_attn_bias = layers.expand(tgt_src_attn_bias,
                                      [1, beam_size, n_head, 1, 1])
    enc_bias_shape = layers.shape(tgt_src_attn_bias)
    tgt_src_attn_bias = layers.reshape(
        tgt_src_attn_bias,
        [-1, enc_bias_shape[2], enc_bias_shape[3], enc_bias_shape[4]])

    beam_search = BeamSearch(beam_size, batch_size, decode_alpha,
                             trg_vocab_size, d_model)

    caches = [{
        "k": layers.fill_constant(shape=[caches_batch_size, 0, d_model],
                                  dtype=enc_output.dtype, value=0),
        "v": layers.fill_constant(shape=[caches_batch_size, 0, d_model],
                                  dtype=enc_output.dtype, value=0)
    } for i in range(n_layer)]

    finished_seq = layers.zeros_like(alive_seq)
    finished_scores = layers.fill_constant([batch_size, beam_size],
                                           dtype='float32', value=-INF)
    finished_flags = layers.fill_constant([batch_size, beam_size],
                                          dtype='float32', value=0)

    with while_op.block():
        pos = layers.fill_constant([caches_batch_size, 1, 1], dtype='int64', value=1)
        pos = layers.elementwise_mul(pos, step_idx, axis=0)

        alive_seq_1 = layers.reshape(alive_seq, [caches_batch_size, -1])
        alive_seq_2 = alive_seq_1[:, -1:]
        alive_seq_2 = layers.unsqueeze(alive_seq_2, axes=[1])

        logits = wrap_decoder(
            trg_vocab_size, max_in_len, n_layer, n_head, d_key, d_value,
            d_model, d_inner_hid, prepostprocess_dropout, attention_dropout,
            relu_dropout, preprocess_cmd, postprocess_cmd, weight_sharing,
            embedding_sharing,
            dec_inputs=(alive_seq_2, alive_seq_2, pos, None, tgt_src_attn_bias),
            enc_output=enc_output, caches=caches, is_train=False,
            params_type=params_type)

        alive_seq_2, alive_log_probs_2, finished_seq_2, finished_scores_2, \
            finished_flags_2, caches_2 = beam_search.inner_func(
                step_idx, logits, alive_seq_1, alive_log_probs, finished_seq,
                finished_scores, finished_flags, caches, enc_output,
                tgt_src_attn_bias)

        layers.increment(x=step_idx, value=1.0, in_place=True)
        finish_cond = beam_search.is_finished(
            step_idx, source_length, alive_log_probs_2, finished_scores_2,
            finished_flags_2)

        layers.assign(alive_seq_2, alive_seq)
        layers.assign(alive_log_probs_2, alive_log_probs)
        layers.assign(finished_seq_2, finished_seq)
        layers.assign(finished_scores_2, finished_scores)
        layers.assign(finished_flags_2, finished_flags)

        for i in xrange(len(caches_2)):
            layers.assign(caches_2[i]["k"], caches[i]["k"])
            layers.assign(caches_2[i]["v"], caches[i]["v"])

        layers.logical_and(x=cond, y=finish_cond, out=cond)

    finished_flags = layers.reduce_sum(finished_flags, dim=1,
                                       keep_dim=True) / beam_size
    finished_flags = layers.cast(finished_flags, 'bool')
    mask = layers.cast(
        layers.reduce_any(input=finished_flags, dim=1, keep_dim=True), 'float32')
    mask = layers.expand(mask, [1, beam_size])

    mask2 = 1.0 - mask
    finished_seq = layers.cast(finished_seq, 'float32')
    alive_seq = layers.cast(alive_seq, 'float32')
    #print mask
    finished_seq = layers.elementwise_mul(finished_seq, mask, axis=0) + \
        layers.elementwise_mul(alive_seq, mask2, axis=0)
    finished_seq = layers.cast(finished_seq, 'int32')
    finished_scores = layers.elementwise_mul(finished_scores, mask, axis=0) + \
        layers.elementwise_mul(alive_log_probs, mask2)
    finished_seq.persistable = True
    finished_scores.persistable = True

    return finished_seq, finished_scores
def beam_search(self,
                src_word,
                src_pos,
                src_slf_attn_bias,
                trg_word,
                trg_src_attn_bias,
                bos_id=0,
                eos_id=1,
                beam_size=4,
                max_len=256):
    def expand_to_beam_size(tensor, beam_size):
        tensor = layers.reshape(tensor, [tensor.shape[0], 1] + tensor.shape[1:])
        tile_dims = [1] * len(tensor.shape)
        tile_dims[1] = beam_size
        return layers.expand(tensor, tile_dims)

    def merge_batch_beams(tensor):
        return layers.reshape(
            tensor, [tensor.shape[0] * tensor.shape[1]] + tensor.shape[2:])

    def split_batch_beams(tensor):
        return fluid.layers.reshape(
            tensor, shape=[-1, beam_size] + list(tensor.shape[1:]))

    def mask_probs(probs, finished, noend_mask_tensor):
        # TODO: use where_op
        finished = layers.cast(finished, dtype=probs.dtype)
        probs = layers.elementwise_mul(
            layers.expand(layers.unsqueeze(finished, [2]),
                          [1, 1, self.trg_vocab_size]),
            noend_mask_tensor,
            axis=-1) - layers.elementwise_mul(probs, (finished - 1), axis=0)
        return probs

    def gather(x, indices, batch_pos):
        topk_coordinates = fluid.layers.stack([batch_pos, indices], axis=2)
        return layers.gather_nd(x, topk_coordinates)

    # run encoder
    enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)

    # constant number
    inf = float(1. * 1e7)
    batch_size = enc_output.shape[0]
    max_len = (enc_output.shape[1] + 20) if max_len is None else max_len
    vocab_size_tensor = layers.fill_constant(shape=[1], dtype="int64",
                                             value=self.trg_vocab_size)
    end_token_tensor = to_variable(
        np.full([batch_size, beam_size], eos_id, dtype="int64"))

    noend_array = [-inf] * self.trg_vocab_size
    noend_array[eos_id] = 0
    noend_mask_tensor = to_variable(np.array(noend_array, dtype="float32"))
    batch_pos = layers.expand(
        layers.unsqueeze(
            to_variable(np.arange(0, batch_size, 1, dtype="int64")), [1]),
        [1, beam_size])

    predict_ids = []
    parent_ids = []
    ### initialize states of beam search ###
    log_probs = to_variable(
        np.array([[0.] + [-inf] * (beam_size - 1)] * batch_size,
                 dtype="float32"))
    finished = to_variable(np.full([batch_size, beam_size], 0, dtype="bool"))

    ### initialize inputs and states of transformer decoder ###
    ## init inputs for decoder, shaped `[batch_size*beam_size, ...]`
    trg_word = layers.fill_constant(shape=[batch_size * beam_size, 1],
                                    dtype="int64", value=bos_id)
    trg_pos = layers.zeros_like(trg_word)
    trg_src_attn_bias = merge_batch_beams(
        expand_to_beam_size(trg_src_attn_bias, beam_size))
    enc_output = merge_batch_beams(expand_to_beam_size(enc_output, beam_size))

    ## init states (caches) for transformer, need to be updated according to selected beam
    caches = [{
        "k": layers.fill_constant(
            shape=[batch_size * beam_size, self.n_head, 0, self.d_key],
            dtype=enc_output.dtype, value=0),
        "v": layers.fill_constant(
            shape=[batch_size * beam_size, self.n_head, 0, self.d_value],
            dtype=enc_output.dtype, value=0),
    } for i in range(self.n_layer)]

    for i in range(max_len):
        trg_pos = layers.fill_constant(shape=trg_word.shape, dtype="int64", value=i)
        caches = map_structure(  # can not be reshaped since the 0 size
            lambda x: x if i == 0 else merge_batch_beams(x), caches)
        logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
                              enc_output, caches)
        caches = map_structure(split_batch_beams, caches)
        step_log_probs = split_batch_beams(
            fluid.layers.log(fluid.layers.softmax(logits)))
        step_log_probs = mask_probs(step_log_probs, finished, noend_mask_tensor)
        log_probs = layers.elementwise_add(x=step_log_probs, y=log_probs, axis=0)
        log_probs = layers.reshape(log_probs,
                                   [-1, beam_size * self.trg_vocab_size])
        scores = log_probs
        topk_scores, topk_indices = fluid.layers.topk(input=scores, k=beam_size)
        beam_indices = fluid.layers.elementwise_floordiv(topk_indices,
                                                         vocab_size_tensor)
        token_indices = fluid.layers.elementwise_mod(topk_indices,
                                                     vocab_size_tensor)

        # update states
        caches = map_structure(lambda x: gather(x, beam_indices, batch_pos),
                               caches)
        log_probs = gather(log_probs, topk_indices, batch_pos)
        finished = gather(finished, beam_indices, batch_pos)
        finished = layers.logical_or(
            finished, layers.equal(token_indices, end_token_tensor))
        trg_word = layers.reshape(token_indices, [-1, 1])

        predict_ids.append(token_indices)
        parent_ids.append(beam_indices)

        if layers.reduce_all(finished).numpy():
            break

    predict_ids = layers.stack(predict_ids, axis=0)
    parent_ids = layers.stack(parent_ids, axis=0)
    finished_seq = layers.transpose(
        layers.gather_tree(predict_ids, parent_ids), [1, 2, 0])
    finished_scores = topk_scores

    return finished_seq, finished_scores
def beam_search_v2(self,
                   src_word,
                   src_pos,
                   src_slf_attn_bias,
                   trg_word,
                   trg_src_attn_bias,
                   bos_id=0,
                   eos_id=1,
                   beam_size=4,
                   max_len=None,
                   alpha=0.6):
    """
    Beam search with separate alive and finished queues, each with a beam size
    capacity. It includes `grow_topk`, `grow_alive` and `grow_finish` as steps.

    1. `grow_topk` selects the top `2*beam_size` candidates to avoid all getting EOS.
    2. `grow_alive` selects the top `beam_size` non-EOS candidates as the inputs of
       the next decoding step.
    3. `grow_finish` compares the already finished candidates in the finished queue
       with the newly finished candidates from `grow_topk`, and selects the top
       `beam_size` finished candidates.
    """
    def expand_to_beam_size(tensor, beam_size):
        tensor = layers.reshape(tensor, [tensor.shape[0], 1] + tensor.shape[1:])
        tile_dims = [1] * len(tensor.shape)
        tile_dims[1] = beam_size
        return layers.expand(tensor, tile_dims)

    def merge_beam_dim(tensor):
        return layers.reshape(tensor, [-1] + tensor.shape[2:])

    # run encoder
    enc_output = self.encoder(src_word, src_pos, src_slf_attn_bias)

    # constant number
    inf = float(1. * 1e7)
    batch_size = enc_output.shape[0]
    max_len = (enc_output.shape[1] + 20) if max_len is None else max_len

    ### initialize states of beam search ###
    ## init for the alive ##
    initial_log_probs = to_variable(
        np.array([[0.] + [-inf] * (beam_size - 1)], dtype="float32"))
    alive_log_probs = layers.expand(initial_log_probs, [batch_size, 1])
    alive_seq = to_variable(
        np.tile(np.array([[[bos_id]]], dtype="int64"),
                (batch_size, beam_size, 1)))

    ## init for the finished ##
    finished_scores = to_variable(
        np.array([[-inf] * beam_size], dtype="float32"))
    finished_scores = layers.expand(finished_scores, [batch_size, 1])
    finished_seq = to_variable(
        np.tile(np.array([[[bos_id]]], dtype="int64"),
                (batch_size, beam_size, 1)))
    finished_flags = layers.zeros_like(finished_scores)

    ### initialize inputs and states of transformer decoder ###
    ## init inputs for decoder, shaped `[batch_size*beam_size, ...]`
    trg_word = layers.reshape(alive_seq[:, :, -1], [batch_size * beam_size, 1])
    trg_src_attn_bias = merge_beam_dim(
        expand_to_beam_size(trg_src_attn_bias, beam_size))
    enc_output = merge_beam_dim(expand_to_beam_size(enc_output, beam_size))

    ## init states (caches) for transformer, need to be updated according to selected beam
    caches = [{
        "k": layers.fill_constant(
            shape=[batch_size * beam_size, self.n_head, 0, self.d_key],
            dtype=enc_output.dtype, value=0),
        "v": layers.fill_constant(
            shape=[batch_size * beam_size, self.n_head, 0, self.d_value],
            dtype=enc_output.dtype, value=0),
    } for i in range(self.n_layer)]

    def update_states(caches, beam_idx, beam_size):
        for cache in caches:
            cache["k"] = gather_2d_by_gather(cache["k"], beam_idx, beam_size,
                                             batch_size, False)
            cache["v"] = gather_2d_by_gather(cache["v"], beam_idx, beam_size,
                                             batch_size, False)
        return caches

    def gather_2d_by_gather(tensor_nd, beam_idx, beam_size, batch_size,
                            need_flat=True):
        batch_idx = layers.range(0, batch_size, 1, dtype="int64") * beam_size
        flat_tensor = merge_beam_dim(tensor_nd) if need_flat else tensor_nd
        idx = layers.reshape(layers.elementwise_add(beam_idx, batch_idx, 0), [-1])
        new_flat_tensor = layers.gather(flat_tensor, idx)
        new_tensor_nd = layers.reshape(
            new_flat_tensor,
            shape=[batch_size, beam_idx.shape[1]] +
            tensor_nd.shape[2:]) if need_flat else new_flat_tensor
        return new_tensor_nd

    def early_finish(alive_log_probs, finished_scores, finished_in_finished):
        max_length_penalty = np.power(((5. + max_len) / 6.), alpha)
        # The best possible score of the most likely alive sequence
        lower_bound_alive_scores = alive_log_probs[:, 0] / max_length_penalty

        # Now to compute the lowest score of a finished sequence in finished
        # If the sequence isn't finished, we multiply it's score by 0. since
        # scores are all -ve, taking the min will give us the score of the lowest
        # finished item.
        lowest_score_of_fininshed_in_finished = layers.reduce_min(
            finished_scores * finished_in_finished, 1)
        # If none of the sequences have finished, then the min will be 0 and
        # we have to replace it by -ve INF if it is. The score of any seq in alive
        # will be much higher than -ve INF and the termination condition will not
        # be met.
        lowest_score_of_fininshed_in_finished += (
            1. - layers.reduce_max(finished_in_finished, 1)) * -inf
        bound_is_met = layers.reduce_all(
            layers.greater_than(lowest_score_of_fininshed_in_finished,
                                lower_bound_alive_scores))

        return bound_is_met

    def grow_topk(i, logits, alive_seq, alive_log_probs, states):
        logits = layers.reshape(logits, [batch_size, beam_size, -1])
        candidate_log_probs = layers.log(layers.softmax(logits, axis=2))
        log_probs = layers.elementwise_add(candidate_log_probs,
                                           alive_log_probs, 0)

        length_penalty = np.power(5.0 + (i + 1.0) / 6.0, alpha)
        curr_scores = log_probs / length_penalty
        flat_curr_scores = layers.reshape(curr_scores, [batch_size, -1])

        topk_scores, topk_ids = layers.topk(flat_curr_scores, k=beam_size * 2)

        topk_log_probs = topk_scores * length_penalty

        topk_beam_index = topk_ids // self.trg_vocab_size
        topk_ids = topk_ids % self.trg_vocab_size

        # use gather as gather_nd, TODO: use gather_nd
        topk_seq = gather_2d_by_gather(alive_seq, topk_beam_index, beam_size,
                                       batch_size)
        topk_seq = layers.concat(
            [topk_seq, layers.reshape(topk_ids, topk_ids.shape + [1])], axis=2)
        states = update_states(states, topk_beam_index, beam_size)
        eos = layers.fill_constant(shape=topk_ids.shape, dtype="int64",
                                   value=eos_id)
        topk_finished = layers.cast(layers.equal(topk_ids, eos), "float32")

        # topk_seq: [batch_size, 2*beam_size, i+1]
        # topk_log_probs, topk_scores, topk_finished: [batch_size, 2*beam_size]
        return topk_seq, topk_log_probs, topk_scores, topk_finished, states

    def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished, states):
        curr_scores += curr_finished * -inf
        _, topk_indexes = layers.topk(curr_scores, k=beam_size)
        alive_seq = gather_2d_by_gather(curr_seq, topk_indexes, beam_size * 2,
                                        batch_size)
        alive_log_probs = gather_2d_by_gather(curr_log_probs, topk_indexes,
                                              beam_size * 2, batch_size)
        states = update_states(states, topk_indexes, beam_size * 2)

        return alive_seq, alive_log_probs, states

    def grow_finished(finished_seq, finished_scores, finished_flags, curr_seq,
                      curr_scores, curr_finished):
        # finished scores
        finished_seq = layers.concat(
            [finished_seq,
             layers.fill_constant(shape=[batch_size, beam_size, 1],
                                  dtype="int64", value=eos_id)],
            axis=2)
        # Set the scores of the unfinished seq in curr_seq to large negative values
        curr_scores += (1. - curr_finished) * -inf
        # concatenating the sequences and scores along beam axis
        curr_finished_seq = layers.concat([finished_seq, curr_seq], axis=1)
        curr_finished_scores = layers.concat([finished_scores, curr_scores],
                                             axis=1)
        curr_finished_flags = layers.concat([finished_flags, curr_finished],
                                            axis=1)
        _, topk_indexes = layers.topk(curr_finished_scores, k=beam_size)
        finished_seq = gather_2d_by_gather(curr_finished_seq, topk_indexes,
                                           beam_size * 3, batch_size)
        finished_scores = gather_2d_by_gather(curr_finished_scores,
                                              topk_indexes, beam_size * 3,
                                              batch_size)
        finished_flags = gather_2d_by_gather(curr_finished_flags, topk_indexes,
                                             beam_size * 3, batch_size)

        return finished_seq, finished_scores, finished_flags

    for i in range(max_len):
        trg_pos = layers.fill_constant(shape=trg_word.shape, dtype="int64",
                                       value=i)
        logits = self.decoder(trg_word, trg_pos, None, trg_src_attn_bias,
                              enc_output, caches)
        topk_seq, topk_log_probs, topk_scores, topk_finished, states = grow_topk(
            i, logits, alive_seq, alive_log_probs, caches)
        alive_seq, alive_log_probs, states = grow_alive(
            topk_seq, topk_scores, topk_log_probs, topk_finished, states)
        finished_seq, finished_scores, finished_flags = grow_finished(
            finished_seq, finished_scores, finished_flags, topk_seq,
            topk_scores, topk_finished)
        trg_word = layers.reshape(alive_seq[:, :, -1],
                                  [batch_size * beam_size, 1])

        if early_finish(alive_log_probs, finished_scores,
                        finished_flags).numpy():
            break

    return finished_seq, finished_scores
def train(self):
    self.genA2B.train(), self.genB2A.train(), self.disGA.train(
    ), self.disGB.train(), self.disLA.train(), self.disLB.train()

    start_iter = 1
    # TODO: resuming from a checkpoint has not been implemented yet.
    # if self.resume:
    #     # glob returns the paths of files matching '*.pt'
    #     model_list = glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt'))
    #     if not len(model_list) == 0:
    #         model_list.sort()
    #         start_iter = int(model_list[-1].split('_')[-1].split('.')[0])
    #         self.load(os.path.join(self.result_dir, self.dataset, 'model'), start_iter)
    #         print(" [*] Load SUCCESS")
    #         if self.decay_flag and start_iter > (self.iteration // 2):
    #             self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2)
    #             self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2)) * (start_iter - self.iteration // 2)

    # training loop
    print('training start !')
    start_time = time.time()
    for step in range(start_iter, self.iteration + 1):
        # TODO: linear learning-rate decay over the second half of training
        # if self.decay_flag and step > (self.iteration // 2):
        #     self.G_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2))
        #     self.D_optim.param_groups[0]['lr'] -= (self.lr / (self.iteration // 2))

        # fetch the next batch from each domain, restarting the loader when exhausted
        try:
            real_A, _ = next(trainA_iter)
        except (NameError, StopIteration):
            trainA_iter = self.trainA_loader()
            real_A, _ = next(trainA_iter)

        try:
            real_B, _ = next(trainB_iter)
        except (NameError, StopIteration):
            trainB_iter = self.trainB_loader()
            real_B, _ = next(trainB_iter)

        # Update D
        self.D_optim.clear_gradients()

        fake_A2B, _, _ = self.genA2B(real_A)
        fake_B2A, _, _ = self.genB2A(real_B)

        real_GA_logit, real_GA_cam_logit, _ = self.disGA(real_A)
        real_LA_logit, real_LA_cam_logit, _ = self.disLA(real_A)
        real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
        real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

        fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
        fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
        fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
        fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

        # LSGAN-style adversarial losses: real logits pushed towards 1, fake logits towards 0
        D_ad_loss_GA = self.MSE_loss(real_GA_logit, layers.ones_like(real_GA_logit)) + \
            self.MSE_loss(fake_GA_logit, layers.zeros_like(fake_GA_logit))
        D_ad_cam_loss_GA = self.MSE_loss(real_GA_cam_logit, layers.ones_like(real_GA_cam_logit)) + \
            self.MSE_loss(fake_GA_cam_logit, layers.zeros_like(fake_GA_cam_logit))
        D_ad_loss_LA = self.MSE_loss(real_LA_logit, layers.ones_like(real_LA_logit)) + \
            self.MSE_loss(fake_LA_logit, layers.zeros_like(fake_LA_logit))
        D_ad_cam_loss_LA = self.MSE_loss(real_LA_cam_logit, layers.ones_like(real_LA_cam_logit)) + \
            self.MSE_loss(fake_LA_cam_logit, layers.zeros_like(fake_LA_cam_logit))
        D_ad_loss_GB = self.MSE_loss(real_GB_logit, layers.ones_like(real_GB_logit)) + \
            self.MSE_loss(fake_GB_logit, layers.zeros_like(fake_GB_logit))
        D_ad_cam_loss_GB = self.MSE_loss(real_GB_cam_logit, layers.ones_like(real_GB_cam_logit)) + \
            self.MSE_loss(fake_GB_cam_logit, layers.zeros_like(fake_GB_cam_logit))
        D_ad_loss_LB = self.MSE_loss(real_LB_logit, layers.ones_like(real_LB_logit)) + \
            self.MSE_loss(fake_LB_logit, layers.zeros_like(fake_LB_logit))
        D_ad_cam_loss_LB = self.MSE_loss(real_LB_cam_logit, layers.ones_like(real_LB_cam_logit)) + \
            self.MSE_loss(fake_LB_cam_logit, layers.zeros_like(fake_LB_cam_logit))

        D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA +
                                      D_ad_loss_LA + D_ad_cam_loss_LA)
        D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB +
                                      D_ad_loss_LB + D_ad_cam_loss_LB)

        Discriminator_loss = D_loss_A + D_loss_B
        Discriminator_loss.backward()
        self.D_optim.minimize(Discriminator_loss)

        # Update G
        self.G_optim.clear_gradients()

        fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
        fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

        fake_A2B2A, _, _ = self.genB2A(fake_A2B)
        fake_B2A2B, _, _ = self.genA2B(fake_B2A)

        fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
        fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

        fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
        fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
        fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
        fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

        G_ad_loss_GA = self.MSE_loss(fake_GA_logit, layers.ones_like(fake_GA_logit))
        G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit, layers.ones_like(fake_GA_cam_logit))
        G_ad_loss_LA = self.MSE_loss(fake_LA_logit, layers.ones_like(fake_LA_logit))
        G_ad_cam_loss_LA = self.MSE_loss(fake_LA_cam_logit, layers.ones_like(fake_LA_cam_logit))
        G_ad_loss_GB = self.MSE_loss(fake_GB_logit, layers.ones_like(fake_GB_logit))
        G_ad_cam_loss_GB = self.MSE_loss(fake_GB_cam_logit, layers.ones_like(fake_GB_cam_logit))
        G_ad_loss_LB = self.MSE_loss(fake_LB_logit, layers.ones_like(fake_LB_logit))
        G_ad_cam_loss_LB = self.MSE_loss(fake_LB_cam_logit, layers.ones_like(fake_LB_cam_logit))

        # cycle-consistency and identity losses
        G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
        G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

        G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
        G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

        # CAM losses: translated samples should activate, identity mappings should not
        G_cam_loss_A = self.BCE_loss(fake_B2A_cam_logit, layers.ones_like(fake_B2A_cam_logit)) + \
            self.BCE_loss(fake_A2A_cam_logit, layers.zeros_like(fake_A2A_cam_logit))
        G_cam_loss_B = self.BCE_loss(fake_A2B_cam_logit, layers.ones_like(fake_A2B_cam_logit)) + \
            self.BCE_loss(fake_B2B_cam_logit, layers.zeros_like(fake_B2B_cam_logit))

        G_loss_A = self.adv_weight * (G_ad_loss_GA + G_ad_cam_loss_GA +
                                      G_ad_loss_LA + G_ad_cam_loss_LA) + \
            self.cycle_weight * G_recon_loss_A + \
            self.identity_weight * G_identity_loss_A + \
            self.cam_weight * G_cam_loss_A
        G_loss_B = self.adv_weight * (G_ad_loss_GB + G_ad_cam_loss_GB +
                                      G_ad_loss_LB + G_ad_cam_loss_LB) + \
            self.cycle_weight * G_recon_loss_B + \
            self.identity_weight * G_identity_loss_B + \
            self.cam_weight * G_cam_loss_B

        Generator_loss = G_loss_A + G_loss_B
        Generator_loss.backward()
        self.G_optim.minimize(Generator_loss)

        # clip the rho parameters of AdaILN and ILN, applied after the optimizer step
        self.genA2B.apply(self.Rho_clipper)
        self.genB2A.apply(self.Rho_clipper)

        print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" %
              (step, self.iteration, time.time() - start_time,
               Discriminator_loss.numpy(), Generator_loss.numpy()))

        if step % self.print_freq == 0:
            train_sample_num = 5
            test_sample_num = 5
            # each sample column stacks 7 images: input, identity CAM, identity output,
            # translation CAM, translation, cycle CAM, cycle reconstruction
            A2B = np.zeros((self.img_size * 7, 0, 3))
            B2A = np.zeros((self.img_size * 7, 0, 3))

            self.genA2B.eval(), self.genB2A.eval(), self.disGA.eval(
            ), self.disGB.eval(), self.disLA.eval(), self.disLB.eval()
            for _ in range(train_sample_num):
                try:
                    real_A, _ = next(trainA_iter)
                except (NameError, StopIteration):
                    trainA_iter = iter(self.trainA_loader)
                    real_A, _ = next(trainA_iter)

                try:
                    real_B, _ = next(trainB_iter)
                except (NameError, StopIteration):
                    trainB_iter = iter(self.trainB_loader)
                    real_B, _ = next(trainB_iter)

                fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                A2B = np.concatenate(
                    (A2B,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                          cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                          cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                          cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)

                B2A = np.concatenate(
                    (B2A,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                          cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                          cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                          cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)

            for _ in range(test_sample_num):
                try:
                    real_A, _ = next(testA_iter)
                except (NameError, StopIteration):
                    testA_iter = iter(self.testA_loader)
                    real_A, _ = next(testA_iter)

                try:
                    real_B, _ = next(testB_iter)
                except (NameError, StopIteration):
                    testB_iter = iter(self.testB_loader)
                    real_B, _ = next(testB_iter)

                fake_A2B, _, fake_A2B_heatmap = self.genA2B(real_A)
                fake_B2A, _, fake_B2A_heatmap = self.genB2A(real_B)

                fake_A2B2A, _, fake_A2B2A_heatmap = self.genB2A(fake_A2B)
                fake_B2A2B, _, fake_B2A2B_heatmap = self.genA2B(fake_B2A)

                fake_A2A, _, fake_A2A_heatmap = self.genB2A(real_A)
                fake_B2B, _, fake_B2B_heatmap = self.genA2B(real_B)

                A2B = np.concatenate(
                    (A2B,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_A[0]))),
                          cam(tensor2numpy(fake_A2A_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2A[0]))),
                          cam(tensor2numpy(fake_A2B_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B[0]))),
                          cam(tensor2numpy(fake_A2B2A_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_A2B2A[0])))), 0)), 1)

                B2A = np.concatenate(
                    (B2A,
                     np.concatenate(
                         (RGB2BGR(tensor2numpy(denorm(real_B[0]))),
                          cam(tensor2numpy(fake_B2B_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2B[0]))),
                          cam(tensor2numpy(fake_B2A_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A[0]))),
                          cam(tensor2numpy(fake_B2A2B_heatmap[0]), self.img_size),
                          RGB2BGR(tensor2numpy(denorm(fake_B2A2B[0])))), 0)), 1)

            cv2.imwrite(
                os.path.join(self.result_dir, self.dataset, 'img',
                             'A2B_%07d.png' % step), A2B * 255.0)
            cv2.imwrite(
                os.path.join(self.result_dir, self.dataset, 'img',
                             'B2A_%07d.png' % step), B2A * 255.0)

            self.genA2B.train(), self.genB2A.train(), self.disGA.train(
            ), self.disGB.train(), self.disLA.train(), self.disLB.train()

        if step % self.save_freq == 0:
            self.save(os.path.join(self.result_dir, self.dataset, 'model'),
                      step)

        if step % 1000 == 0:
            params = {}
            params['genA2B'] = self.genA2B.state_dict()
            params['genB2A'] = self.genB2A.state_dict()
            params['disGA'] = self.disGA.state_dict()
            params['disGB'] = self.disGB.state_dict()
            params['disLA'] = self.disLA.state_dict()
            params['disLB'] = self.disLB.state_dict()
            fluid.save_dygraph(
                params,
                os.path.join(self.result_dir,
                             self.dataset + '_params_latest'))
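The sampling block above relies on four small image helpers (`denorm`, `tensor2numpy`, `RGB2BGR`, `cam`) that are not defined in this section. Below is a minimal sketch of what these helpers typically look like in UGATIT-style ports, assuming inputs are CHW dygraph tensors normalized to [-1, 1]; the exact definitions here are assumptions, not this repository's own code.

```python
import cv2
import numpy as np


def denorm(x):
    # map a tensor normalized to [-1, 1] back to [0, 1]
    return x * 0.5 + 0.5


def tensor2numpy(x):
    # CHW dygraph tensor -> HWC numpy array
    return x.numpy().transpose(1, 2, 0)


def RGB2BGR(x):
    # cv2.imwrite expects BGR channel order
    return cv2.cvtColor(x, cv2.COLOR_RGB2BGR)


def cam(x, size):
    # turn a CAM heatmap into a JET-colormapped image in [0, 1], resized to the sample size
    x = x - np.min(x)
    cam_img = x / np.max(x)
    cam_img = np.uint8(255 * cam_img)
    cam_img = cv2.resize(cam_img, (size, size))
    cam_img = cv2.applyColorMap(cam_img, cv2.COLORMAP_JET)
    return cam_img / 255.0
```

With these helpers, each call in the grid-building code produces one `img_size x img_size x 3` tile in [0, 1], which is why the final grids are multiplied by 255 before `cv2.imwrite`.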
def _ernie_model_forward(self,
                         src_ids,
                         sent_ids=None,
                         pos_ids=None,
                         input_mask=None,
                         attn_bias=None,
                         past_cache=None,
                         use_causal_mask=False,
                         num_layers=12,
                         depth=1.,
                         head_mask=None):
    assert len(src_ids.shape) == 2, \
        'expect src_ids.shape = [batch, sequence], got %s' % (repr(src_ids.shape))
    assert attn_bias is not None if past_cache else True, \
        'if `past_cache` is specified, attn_bias should not be None'
    d_batch = L.shape(src_ids)[0]
    d_seqlen = L.shape(src_ids)[1]
    if pos_ids is None:
        pos_ids = L.reshape(L.range(0, d_seqlen, 1, dtype='int32'), [1, -1])
        pos_ids = L.cast(pos_ids, 'int64')
    if attn_bias is None:
        if input_mask is None:
            input_mask = L.cast(src_ids != 0, 'float32')
        assert len(input_mask.shape) == 2
        input_mask = L.unsqueeze(input_mask, axes=[-1])
        attn_bias = L.matmul(input_mask, input_mask, transpose_y=True)
        if use_causal_mask:
            # positions are indexed from 1; sequence[i] / sequence[j] >= 1 iff i >= j,
            # which yields a lower-triangular (causal) mask
            sequence = L.reshape(
                L.range(0, d_seqlen, 1, dtype='float32') + 1., [1, 1, -1, 1])
            causal_mask = L.cast(
                (L.matmul(sequence, 1. / sequence, transpose_y=True) >= 1.),
                'float32')
            attn_bias *= causal_mask
    else:
        assert len(attn_bias.shape) == 3, \
            'expect attn_bias to be rank 3, got %r' % attn_bias.shape

    # convert the {0, 1} mask into an additive bias: masked positions get -10000
    attn_bias = (1. - attn_bias) * -10000.0
    attn_bias = L.unsqueeze(attn_bias, [1])
    attn_bias.stop_gradient = True

    if sent_ids is None:
        sent_ids = L.zeros_like(src_ids)

    if head_mask is not None:
        # broadcast head_mask to [num_layers, batch, num_heads, seq_len, seq_len]
        if len(head_mask.shape) == 1:
            head_mask = L.unsqueeze(
                L.unsqueeze(L.unsqueeze(L.unsqueeze(head_mask, 0), 0), -1), -1)
            head_mask = L.expand(head_mask,
                                 expand_times=[num_layers, 1, 1, 1, 1])
        elif len(head_mask.shape) == 2:
            head_mask = L.unsqueeze(
                L.unsqueeze(L.unsqueeze(head_mask, 1), -1), -1)
    else:
        head_mask = [None] * num_layers

    src_embedded = self.word_emb(src_ids)
    pos_embedded = self.pos_emb(pos_ids)
    sent_embedded = self.sent_emb(sent_ids)
    embedded = src_embedded + pos_embedded + sent_embedded

    embedded = self.dropout(self.ln(embedded))

    encoded, hidden_list, cache_list = self.encoder_stack(
        embedded,
        attn_bias,
        past_cache=past_cache,
        num_layers=num_layers,
        depth_mult=depth,
        head_mask=head_mask)
    if self.pooler is not None:
        pooled = self.pooler(encoded[:, 0, :])
    else:
        pooled = None

    additional_info = {
        'hiddens': hidden_list,
        'caches': cache_list,
    }

    if self.return_additional_info:
        return pooled, encoded, additional_info
    else:
        return pooled, encoded
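The `use_causal_mask` branch builds the lower-triangular mask without a dedicated triangular op: with positions indexed from 1, the ratio `sequence[i] / sequence[j]` is at least 1 exactly when `i >= j`. A small NumPy check of the same construction (NumPy is used here only for illustration; the model itself does this with `L.matmul` on dygraph tensors):

```python
import numpy as np

seqlen = 4
# positions 1..seqlen as a column vector
sequence = np.arange(1, seqlen + 1, dtype='float32').reshape(-1, 1)
# outer ratio matrix M[i, j] = (i + 1) / (j + 1); thresholding at 1 gives the causal mask
causal = (sequence @ (1.0 / sequence).T >= 1.0).astype('float32')

print(causal)
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]
```

Multiplying this mask into `attn_bias` before the `(1 - attn_bias) * -10000` step zeroes out future positions, so they end up with a large negative additive bias and are ignored by the attention softmax.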