Пример #1
0
 def step(self):
     """Run one distributed training step and return the loss as a dict."""
     prediction = self.model(self.image)
     step_loss = self.criterion(prediction, self.target)
     self.optim.zero_grad()
     step_loss.backward()
     # All-reduce gradients across workers before the optimizer update.
     utils.average_gradients(self.model)
     self.optim.step()
     return {'loss': step_loss}
Пример #2
0
 def step(self):
     """Run a single optimization step and report the flow loss."""
     completion = self.model(self.image_input, self.sparse_input)
     # Divide by world size so the all-reduced gradient is a true average.
     flow_loss = self.flow_criterion(completion, self.flow_target)
     flow_loss = flow_loss / self.world_size
     self.optim.zero_grad()
     flow_loss.backward()
     # Synchronize gradients across distributed workers before stepping.
     utils.average_gradients(self.model)
     self.optim.step()
     return {'loss_flow': flow_loss}
 def step(self):
     """One training step on the mask+eraser input (optionally with RGB)."""
     # Both branches feed the same concatenated mask/eraser tensor.
     net_input = torch.cat([self.mask, self.eraser], dim=1)
     if self.use_rgb:
         prediction = self.model(net_input, self.rgb)
     else:
         prediction = self.model(net_input)
     step_loss = self.criterion(prediction, self.target,
                                self.eraser.squeeze(1)) / self.world_size
     self.optim.zero_grad()
     step_loss.backward()
     utils.average_gradients(self.model)  # sync across workers
     self.optim.step()
     return {'loss': step_loss}
Пример #4
0
 def step(self):
     """Run one optimization step; returns the world-size-normalized loss."""
     if self.use_rgb:
         prediction = self.model(self.mask, self.rgb)
     else:
         prediction = self.model(self.mask)
     step_loss = self.criterion(prediction, self.target) / self.world_size
     self.optim.zero_grad()
     step_loss.backward()
     # Average gradients across all distributed workers.
     utils.average_gradients(self.model)
     self.optim.step()
     return {'loss': step_loss}
Пример #5
0
    def train(self, xs1, xs2, scores):
        """Build the (optionally multi-GPU) training graph for score regression.

        The batch is sharded across all visible GPUs; each tower computes a
        squared-error loss and its gradients under a shared variable scope,
        the tower gradients are averaged and applied once.  Falls back to a
        single-device graph when no GPU is available.

        Returns:
            (loss, train_op, global_step, summaries)
        """
        global_step = tf.train.get_or_create_global_step()
        # Warmup learning-rate schedule (presumably Noam-style — confirm
        # against noam_scheme's definition).
        lr = noam_scheme(self.context.lr, global_step,
                         self.context.warmup_steps)
        optimizer = tf.train.AdamOptimizer(lr)
        gpus = get_available_gpus()

        if gpus:
            num_gpu = len(gpus)
            # The batch must shard evenly across towers.
            assert self.context.hparams.batch_size % num_gpu == 0

            xs1s, xs2s = tf.split(xs1, num_gpu, axis=0), tf.split(xs2,
                                                                  num_gpu,
                                                                  axis=0)
            scoress = tf.split(scores, num_gpu, axis=0)

            tower_grads = []
            losses = []
            with tf.variable_scope(tf.get_variable_scope()) as scope:
                list_predictions = []
                for i in range(num_gpu):
                    with tf.device("/gpu:%d" % i):
                        with tf.name_scope("tower_%d" % i):
                            predictions = self._get_prediction(
                                xs1s[i], xs2s[i])
                            list_predictions.append(predictions)
                            # square loss
                            partial_loss = tf.reduce_sum(tf.squared_difference(
                                predictions, scoress[i]),
                                                         name="loss")
                            losses.append(partial_loss)
                            # Share model variables with subsequent towers.
                            tf.get_variable_scope().reuse_variables()
                            grad = get_gradients_by_loss_and_optimizer(
                                partial_loss, optimizer)
                            tower_grads.append(grad)
                predictions = tf.concat(list_predictions, axis=0)
            loss = tf.reduce_mean(losses)
            grads_and_vars = average_gradients(tower_grads)
        else:
            # Single-device fallback.
            # NOTE(review): this branch uses reduce_sum while the multi-GPU
            # path takes the mean of per-tower sums, so the reported loss
            # scale differs between the two paths — confirm this is intended.
            predictions = self._get_prediction(xs1, xs2)
            loss = tf.reduce_sum(tf.squared_difference(predictions, scores),
                                 name="loss")
            grads_and_vars = get_gradients_by_loss_and_optimizer(
                loss, optimizer)
        train_op = optimizer.apply_gradients(grads_and_vars,
                                             global_step=global_step)

        # Histogram every variable and its gradient for TensorBoard.
        for g, v in grads_and_vars:
            tf.summary.histogram(v.name, v)
            tf.summary.histogram(v.name + '_grad', g)
        tf.summary.scalar("pred_avg", tf.reduce_mean(predictions))
        tf.summary.scalar("label_avg", tf.reduce_mean(scores))

        tf.summary.scalar('lr', lr)
        tf.summary.scalar("loss", loss)
        tf.summary.scalar("global_step", global_step)

        summaries = tf.summary.merge_all()
        return loss, train_op, global_step, summaries
Пример #6
0
    def build_train_graph_multi_gpu(self):
        """Build the data-parallel training graph.

        ``self.hparams.gpu_num`` is a list of GPU ids.  Every placeholder is
        split into one shard per GPU; each tower computes a sparse-softmax
        cross-entropy loss and gradients under the shared "inference"
        variable scope, and the averaged tower gradients drive one Adam
        update.  Sets ``self.loss_op``, ``self.logits``, ``self.train_op``,
        ``self.accuracy`` and ``self.predictions``.
        """
        gpu_num = len(self.hparams.gpu_num)

        # Shard every input placeholder along the batch dimension (axis 0).
        context_ph = tf.split(self.context_ph, gpu_num, 0)
        context_len_ph = tf.split(self.context_len_ph, gpu_num, 0)
        utterances_ph = tf.split(self.utterances_ph, gpu_num, 0)
        utterances_len_ph = tf.split(self.utterances_len_ph, gpu_num, 0)
        target_ph = tf.split(self.target_ph, gpu_num, 0)

        context_sentence_ph = tf.split(self.context_sentence_ph, gpu_num, 0)
        context_sentence_len_ph = tf.split(self.context_sentence_len_ph,
                                           gpu_num, 0)
        tot_context_len_ph = tf.split(self.tot_context_len_ph, gpu_num, 0)
        speaker_ph = tf.split(self.speaker_ph, gpu_num, 0)

        optimizer = tf.train.AdamOptimizer(self.hparams.learning_rate)

        tower_grads = []  # per-tower (gradient, variable) lists
        tot_losses = []   # per-tower mean losses
        tot_logits = []
        tot_labels = []

        for i, gpu_id in enumerate(self.hparams.gpu_num):
            with tf.device('/gpu:%d' % gpu_id):
                # AUTO_REUSE shares the model variables across all towers.
                with tf.variable_scope("inference", reuse=tf.AUTO_REUSE):
                    logits = self._inference(
                        context_ph[i], context_len_ph[i], utterances_ph[i],
                        utterances_len_ph[i], context_sentence_ph[i],
                        context_sentence_len_ph[i], tot_context_len_ph[i],
                        speaker_ph[i])

                    loss_op = tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=logits,
                        labels=target_ph[i],
                        name="cross_entropy")
                    loss_op = tf.reduce_mean(loss_op,
                                             name="cross_entropy_mean")

                    tot_losses.append(loss_op)
                    tot_logits.append(logits)
                    tot_labels.append(target_ph[i])

                    grads = optimizer.compute_gradients(loss_op)

                    tower_grads.append(grads)
                    tf.get_variable_scope().reuse_variables()

        # Average gradients over towers and apply them in a single step.
        grads = average_gradients(tower_grads)
        self.loss_op = tf.divide(tf.add_n(tot_losses), gpu_num)
        self.logits = tf.concat(tot_logits, axis=0)
        tot_labels = tf.concat(tot_labels, axis=0)
        self.train_op = optimizer.apply_gradients(grads, self.global_step)

        # Top-1 accuracy over the re-concatenated batch.
        # NOTE(review): `eval` shadows the Python builtin here.
        eval = tf.nn.in_top_k(self.logits, tot_labels, 1)
        correct_count = tf.reduce_sum(tf.cast(eval, tf.int32))
        self.accuracy = tf.divide(correct_count, tf.shape(self.target_ph)[0])
        self.predictions = argsort(self.logits, axis=1, direction='DESCENDING')
Пример #7
0
    def train(self, inputs, targets):
        """Build the (optionally multi-GPU) training graph.

        ``inputs`` is a list of tensors; each is sharded along the batch
        axis across the available GPUs.  The loss function is looked up in
        ``self._loss_func_dict`` by ``self._context.loss_func``, falling
        back to ``self._get_loss``.

        Returns:
            (loss, train_op, global_step, summaries)
        """
        global_step = tf.train.get_or_create_global_step()
        # Warmup learning-rate schedule for Adam (see noam_scheme).
        lr = noam_scheme(self._context.lr, global_step,
                         self._context.warmup_steps)
        optimizer = tf.train.AdamOptimizer(lr)
        gpus = get_available_gpus()

        loss_func = self._loss_func_dict.get(self._context.loss_func,
                                             self._get_loss)
        if gpus:
            num_gpu = len(gpus)
            # The batch must shard evenly across towers.
            assert self._context.hparams.batch_size % num_gpu == 0

            # Transpose the structure: per-input splits -> per-GPU input lists.
            partial_inputs = [[] for _ in range(num_gpu)]
            for input_tmp in inputs:
                input_tmps = tf.split(input_tmp, num_gpu, axis=0)
                for i in range(num_gpu):
                    partial_inputs[i].append(input_tmps[i])
            targetses = tf.split(targets, num_gpu, axis=0)

            tower_grads = []
            losses = []
            with tf.variable_scope(tf.get_variable_scope()) as scope:
                for i in range(num_gpu):
                    with tf.device("/gpu:%d" % i):
                        with tf.name_scope("tower_%d" % i):
                            partial_loss = loss_func(partial_inputs[i],
                                                     targetses[i])
                            losses.append(partial_loss)
                            # Share model variables with subsequent towers.
                            tf.get_variable_scope().reuse_variables()
                            grad = get_gradients_by_loss_and_optimizer(
                                partial_loss, optimizer)
                            tower_grads.append(grad)
            loss = tf.reduce_mean(losses)
            grads_and_vars = average_gradients(tower_grads)
        else:
            loss = tf.reduce_mean(loss_func(inputs, targets))
            grads_and_vars = get_gradients_by_loss_and_optimizer(
                loss, optimizer)
        train_op = optimizer.apply_gradients(grads_and_vars,
                                             global_step=global_step)

        for g, v in grads_and_vars:
            if g is None:  # variable has no gradient w.r.t. the loss
                continue
            tf.summary.histogram(v.name, v)
            tf.summary.histogram(v.name + '_grad', g)
        # NOTE(review): self._outputs / self._inferences are presumably
        # populated as a side effect of loss_func — confirm before relying
        # on these summaries.
        tf.summary.scalar("pred_avg", tf.reduce_mean(self._outputs))
        tf.summary.scalar("infr_avg", tf.reduce_mean(self._inferences))
        tf.summary.scalar("label_avg", tf.reduce_mean(targets))

        tf.summary.scalar('lr', lr)
        tf.summary.scalar("loss", loss)
        tf.summary.scalar("global_step", global_step)

        summaries = tf.summary.merge_all()
        return loss, train_op, global_step, summaries
Пример #8
0
    def build_train_model(self, reuse=None):
        """Build the multi-device training graph.

        Splits the source/target placeholders across ``self._devices``,
        builds one encoder/decoder tower per device with variables pinned
        to the sync device, averages the tower gradients, clips them by
        global norm, and applies them.  Finally rebuilds the test graph so
        the model can be evaluated during training.
        """
        logging.info('Build train model.')
        self.prepare_training()

        def choose_device(op, device):
            # Keep variable ops on the sync device; run everything else on
            # the tower's own device.
            if op.type.startswith('Variable'):
                return self._sync_device
            return device

        with self.graph.as_default(), tf.device(self._sync_device), \
            tf.variable_scope(tf.get_variable_scope(), initializer=self._initializer, reuse=reuse):
            Xs = split_tensor(self.src_pl, len(self._devices))
            Ys = split_tensor(self.dst_pl, len(self._devices))
            acc_list, loss_list, gv_list = [], [], []
            for i, (X, Y, device) in enumerate(zip(Xs, Ys, self._devices)):
                with tf.device(lambda op: choose_device(op, device)):
                    logging.info('Build model on %s.' % device)
                    # reuse=i > 0 or None: first tower creates the
                    # variables, later towers reuse them.
                    encoder_output = self.encoder(X,
                                                  is_training=True,
                                                  reuse=i > 0 or None)
                    # Decoder consumes the right-shifted target sequence.
                    decoder_output = self.decoder(shift_right(Y),
                                                  encoder_output,
                                                  is_training=True,
                                                  reuse=i > 0 or None)
                    acc, loss = self.train_output(decoder_output,
                                                  Y,
                                                  reuse=i > 0 or None)
                    acc_list.append(acc)
                    loss_list.append(loss)
                    gv_list.append(self._optimizer.compute_gradients(loss))

            self.accuracy = tf.reduce_mean(acc_list)
            self.loss = tf.reduce_mean(loss_list)

            # Clip gradients and then apply.
            grads_and_vars = average_gradients(gv_list)
            for g, v in grads_and_vars:
                tf.summary.histogram('variables/' + v.name.split(':')[0], v)
                tf.summary.histogram('gradients/' + v.name.split(':')[0], g)
            grads, self.grads_norm = tf.clip_by_global_norm(
                [gv[0] for gv in grads_and_vars],
                clip_norm=self._config.train.grads_clip)
            grads_and_vars = zip(grads, [gv[1] for gv in grads_and_vars])
            self.train_op = self._optimizer.apply_gradients(
                grads_and_vars, global_step=self.global_step)

            # Summaries
            tf.summary.scalar('acc', self.accuracy)
            tf.summary.scalar('loss', self.loss)
            tf.summary.scalar('learning_rate', self.learning_rate)
            tf.summary.scalar('grads_norm', self.grads_norm)
            self.summary_op = tf.summary.merge_all()

        # We may want to test the model during training.
        self.build_test_model(reuse=True)
Пример #9
0
    def update(e):
        """Run one policy-gradient + value-function update on the buffer."""
        obs, act, adv, ret, lgp_old = (torch.Tensor(x)
                                       for x in buf.retrieve_all())

        # Policy-gradient loss; entropy is kept as a diagnostic.
        _, lgp, _ = ac.policy(obs, act)
        entropy = (-lgp).mean()
        pi_loss = -(lgp * adv).mean()

        train_pi.zero_grad()
        pi_loss.backward()
        average_gradients(train_pi.param_groups)  # sync across workers
        train_pi.step()

        # Value-function regression: several MSE gradient steps.
        v_l_old = F.mse_loss(ac.value_f(obs), ret)
        for _ in range(train_v_iters):
            v_loss = F.mse_loss(ac.value_f(obs), ret)
            train_v.zero_grad()
            v_loss.backward()
            average_gradients(train_v.param_groups)
            train_v.step()

        # Diagnostics after the update.
        _, lgp, _, v = ac(obs, act)
        entropy_new = (-lgp).mean()
        pi_loss_new = -(lgp * adv).mean()
        v_loss_new = F.mse_loss(v, ret)
        kl = (lgp_old - lgp).mean()
        logger.store(LossPi=pi_loss, LossV=v_l_old,
                     DeltaLossPi=(pi_loss_new - pi_loss),
                     DeltaLossV=(v_loss_new - v_l_old),
                     Entropy=entropy, KL=kl)
 def step(self):
     """One training step of the model; returns the per-term loss dict."""
     if self.with_modal:
         net_in = torch.cat([self.rgb, self.modal], dim=1)
         output, _ = self.model(net_in, self.visible_mask4)
     else:
         output, _ = self.model(self.rgb, self.visible_mask3)
     # Upsample the prediction back to the input resolution if needed.
     if output.shape[2] != self.rgb.shape[2]:
         output = nn.functional.interpolate(output,
                                            size=self.rgb.shape[2:4],
                                            mode="bilinear",
                                            align_corners=True)
     loss_dict = self.criterion(self.rgb, self.visible_mask3, output,
                                self.rgb_gt)
     # Normalize every term for distributed gradient averaging.
     for key in loss_dict:
         loss_dict[key] /= self.world_size
     # Total loss is the lambda-weighted sum of the individual terms.
     total = 0.0
     for key, weight in self.params['lambda_dict'].items():
         total += weight * loss_dict[key]
     self.optim.zero_grad()
     total.backward()
     utils.average_gradients(self.model)
     self.optim.step()
     return loss_dict
Пример #11
0
def train(option="lstm", file_desc=""):
    """Train an LSTM/RNN on the counting task with manual SGD.

    Gradients returned by ``lstm.fit`` are averaged and applied by hand to
    the recurrent weights; the dense weights are read back via ``sess.run``.
    Weights are checkpointed to ``model/`` every 5 epochs.

    Args:
        option: "lstm" or "rnn" — which recurrent cell to train.
        file_desc: prefix for the pickled checkpoint file names.

    Raises:
        ValueError: if ``option`` is neither "lstm" nor "rnn" (previously
            this fell through and crashed later with a NameError).
    """
    epochs = FLAGS.epochs
    batchsize = FLAGS.batchsize
    # Identical seeds so x and y are shuffled in lockstep.
    shuffle_x = np.random.RandomState(42)
    shuffle_y = np.random.RandomState(42)

    task = CountingGame2()
    x, y = task.generate(length=FLAGS.seqlen, samples=FLAGS.samples)
    test_x, test_y = task.generate(length=FLAGS.seqlen, samples=1)

    sess = tf.Session()
    if option == "lstm":
        lstm = LSTM(sess, FLAGS.hidden, FLAGS.seqlen)
    elif option == "rnn":
        lstm = RNN(sess, FLAGS.hidden, FLAGS.seqlen)
    else:
        raise ValueError("unknown option: {}".format(option))

    sess.run(tf.global_variables_initializer())

    lstm_weights = sess.run(lstm.cells[0].lstm_weights)
    lstm.load_weights(lstm_weights)

    # Ceil keeps the trailing partial batch (slicing past the end simply
    # yields a shorter batch), matching np.arange(len(x) / batchsize) but
    # with a proper integer loop bound.
    n_iters = int(np.ceil(len(x) / batchsize))
    for epoch in np.arange(epochs):
        shuffle_x.shuffle(x)
        shuffle_y.shuffle(y)
        for j in range(n_iters):
            start = j * batchsize
            end = start + batchsize
            loss, lstm_gradients = lstm.fit(x[start:end], y[start:end])
            lstm_gradients = utils.average_gradients(lstm_gradients)
            # Manual SGD step on the recurrent weights.
            lstm_weights = [
                w - FLAGS.lr * grad
                for w, grad in zip(lstm_weights, lstm_gradients)
            ]
            dense_weights = sess.run(lstm.dense_weights)
            lstm.load_weights(lstm_weights)
        if epoch % 5 == 0:
            print("\nEpoch #{} Loss: {}".format(epoch, loss))
            print(test_x[0])
            predictions = lstm.test(test_x[0])
            print(np.argmax(predictions))
            with open("model/{}_lstm.pkl".format(file_desc), 'wb') as file:
                pickle.dump(lstm_weights, file)
            with open("model/{}_dense.pkl".format(file_desc), 'wb') as file:
                pickle.dump(dense_weights, file)
Пример #12
0
    def create_train_op(self, optim, iterator, global_step, regularizer_scale=1e-4, train=True, trainable=True, is_scale=True, training_mode='no_distillation'):
        """Create the training op, single- or multi-GPU per ``self.num_gpus``.

        Returns:
            (train_op, losses, regularizer_loss) — in the multi-GPU case
            ``losses`` and ``regularizer_loss`` are averaged over the GPUs.
        """
        if self.num_gpus == 1:
            losses, regularizer_loss = self.build(iterator, regularizer_scale=regularizer_scale, train=train, trainable=trainable, is_scale=is_scale, training_mode=training_mode)
            # NOTE(review): the single-GPU path optimizes a different loss
            # key than the multi-GPU path below — confirm this is intended.
            optim_loss = losses['abs_robust_mean']['no_occlusion']
            train_op = optim.minimize(optim_loss, var_list=tf.trainable_variables(), global_step=global_step)
        else:
            tower_grads = []
            tower_losses = []
            tower_regularizer_losses = []
            with tf.variable_scope(tf.get_variable_scope()):
                for i in range(self.num_gpus):
                    with tf.device('/gpu:%d' % i):
                        with tf.name_scope('tower_{}'.format(i)) as scope:
                            losses_, regularizer_loss_ = self.build(iterator, regularizer_scale=regularizer_scale, train=train, trainable=trainable, is_scale=is_scale, training_mode=training_mode)
                            optim_loss = losses_['census']['occlusion'] + losses_['self_supervision']['self-supervision']
                            # optim_loss = losses_['census']['occlusion']

                            # Reuse variables for the next tower.
                            tf.get_variable_scope().reuse_variables()

                            # NOTE(review): gradients are computed with
                            # self.optim but applied with the `optim`
                            # argument below — confirm they are the same
                            # optimizer instance.
                            grads = self.optim.compute_gradients(optim_loss, var_list=tf.trainable_variables())
                            tower_grads.append(grads)
                            tower_losses.append(losses_)
                            tower_regularizer_losses.append(regularizer_loss_)
                            #self.add_loss_summary(losses_, keys=['abs_robust_mean', 'census'], prefix='tower_%d' % i)

            grads = average_gradients(tower_grads)
            train_op = optim.apply_gradients(grads, global_step=global_step)

            # Average every loss entry over the towers.
            losses = tower_losses[0].copy()
            for key in losses.keys():
                for loss_key, loss_value in losses[key].items():
                    for i in range(1, self.num_gpus):
                        losses[key][loss_key] += tower_losses[i][key][loss_key]
                    losses[key][loss_key] /= self.num_gpus
            regularizer_loss = 0.
            for i in range(self.num_gpus):
                regularizer_loss += tower_regularizer_losses[i]
            regularizer_loss /= self.num_gpus

        self.add_loss_summary(losses, keys=losses.keys())
        tf.summary.scalar('regularizer_loss', regularizer_loss)

        return train_op, losses, regularizer_loss
Пример #13
0
def train(train_queue, model, cnn_optimizer, grad_scalar, global_step,
          warmup_iters, writer, logging):
    """Train the hierarchical VAE for one epoch.

    Uses mixed precision (autocast + gradient scaler), distributed
    gradient/parameter averaging, a KL warm-up coefficient, and
    spectral-norm / batch-norm regularization.  Logs scalars every 100
    steps and reconstructions every 1000 steps.

    Returns:
        (nelbo.avg, global_step) — running-average NELBO and the advanced
        global step counter.
    """
    # Per-group coefficients used to balance the KL across latent scales.
    alpha_i = utils.kl_balancer_coeff(num_scales=model.num_latent_scales,
                                      groups_per_scale=model.groups_per_scale,
                                      fun='square')
    nelbo = utils.AvgrageMeter()
    model.train()
    for step, x in enumerate(train_queue):
        x = x[0] if len(x) > 1 else x
        x = x.half().cuda()

        # change bit length
        x = utils.pre_process(x, args.num_x_bits)

        # warm-up lr: linear ramp over the first warmup_iters steps
        if global_step < warmup_iters:
            lr = args.learning_rate * float(global_step) / warmup_iters
            for param_group in cnn_optimizer.param_groups:
                param_group['lr'] = lr

        # sync parameters, it may not be necessary
        if step % 100 == 0:
            utils.average_params(model.parameters(), args.distributed)

        cnn_optimizer.zero_grad()
        with autocast():
            logits, log_q, log_p, kl_all, kl_diag = model(x)

            output = model.decoder_output(logits)
            # KL warm-up coefficient (see utils.kl_coeff for the schedule).
            kl_coeff = utils.kl_coeff(
                global_step, args.kl_anneal_portion * args.num_total_iter,
                args.kl_const_portion * args.num_total_iter,
                args.kl_const_coeff)

            recon_loss = utils.reconstruction_loss(output,
                                                   x,
                                                   crop=model.crop_output)
            balanced_kl, kl_coeffs, kl_vals = utils.kl_balancer(
                kl_all, kl_coeff, kl_balance=True, alpha_i=alpha_i)

            nelbo_batch = recon_loss + balanced_kl
            loss = torch.mean(nelbo_batch)
            norm_loss = model.spectral_norm_parallel()
            bn_loss = model.batchnorm_loss()
            # get spectral regularization coefficient (lambda)
            if args.weight_decay_norm_anneal:
                assert args.weight_decay_norm_init > 0 and args.weight_decay_norm > 0, 'init and final wdn should be positive.'
                # Interpolate log-linearly between the initial and final
                # coefficient as kl_coeff goes from 0 to 1.
                wdn_coeff = (1. - kl_coeff) * np.log(
                    args.weight_decay_norm_init) + kl_coeff * np.log(
                        args.weight_decay_norm)
                wdn_coeff = np.exp(wdn_coeff)
            else:
                wdn_coeff = args.weight_decay_norm

            loss += norm_loss * wdn_coeff + bn_loss * wdn_coeff

        # Scaled backward, then all-reduce gradients across workers.
        grad_scalar.scale(loss).backward()
        utils.average_gradients(model.parameters(), args.distributed)
        grad_scalar.step(cnn_optimizer)
        grad_scalar.update()
        nelbo.update(loss.data, 1)

        if (global_step + 1) % 100 == 0:
            if (global_step + 1) % 1000 == 0:  # reduced frequency
                # Log a square tile of inputs next to their reconstructions.
                n = int(np.floor(np.sqrt(x.size(0))))
                x_img = x[:n * n]
                output_img = output.mean if isinstance(
                    output, torch.distributions.bernoulli.Bernoulli
                ) else output.sample()
                output_img = output_img[:n * n]
                x_tiled = utils.tile_image(x_img, n)
                output_tiled = utils.tile_image(output_img, n)
                in_out_tiled = torch.cat((x_tiled, output_tiled), dim=2)
                writer.add_image('reconstruction', in_out_tiled, global_step)

            # norm
            writer.add_scalar('train/norm_loss', norm_loss, global_step)
            writer.add_scalar('train/bn_loss', bn_loss, global_step)
            writer.add_scalar('train/norm_coeff', wdn_coeff, global_step)

            utils.average_tensor(nelbo.avg, args.distributed)
            logging.info('train %d %f', global_step, nelbo.avg)
            writer.add_scalar('train/nelbo_avg', nelbo.avg, global_step)
            writer.add_scalar(
                'train/lr',
                cnn_optimizer.state_dict()['param_groups'][0]['lr'],
                global_step)
            writer.add_scalar('train/nelbo_iter', loss, global_step)
            writer.add_scalar('train/kl_iter', torch.mean(sum(kl_all)),
                              global_step)
            writer.add_scalar(
                'train/recon_iter',
                torch.mean(
                    utils.reconstruction_loss(output,
                                              x,
                                              crop=model.crop_output)),
                global_step)
            writer.add_scalar('kl_coeff/coeff', kl_coeff, global_step)
            total_active = 0
            for i, kl_diag_i in enumerate(kl_diag):
                utils.average_tensor(kl_diag_i, args.distributed)
                # A latent dimension counts as "active" when its KL > 0.1.
                num_active = torch.sum(kl_diag_i > 0.1).detach()
                total_active += num_active

                # per-layer kl coeff / value
                writer.add_scalar('kl/active_%d' % i, num_active, global_step)
                writer.add_scalar('kl_coeff/layer_%d' % i, kl_coeffs[i],
                                  global_step)
                writer.add_scalar('kl_vals/layer_%d' % i, kl_vals[i],
                                  global_step)
            writer.add_scalar('kl/total_active', total_active, global_step)

        global_step += 1

    utils.average_tensor(nelbo.avg, args.distributed)
    return nelbo.avg, global_step
Пример #14
0
    def update(e):
        """Run one policy / value / (periodic) discriminator update."""
        batch = [torch.Tensor(x) for x in buffer.retrieve_all()]
        obs, act, adv, pos, ret, logp_old = batch

        # --- Policy update ------------------------------------------------
        _, logp, _ = ac.policy(obs, act)
        entropy = (-logp).mean()
        # Policy-gradient loss with the advantage scaled by k plus the
        # positional bonus term.
        pi_loss = -(logp * (k * adv + pos)).mean()

        train_pi.zero_grad()
        pi_loss.backward()
        average_gradients(train_pi.param_groups)  # sync across workers
        train_pi.step()

        # --- Value-function regression ------------------------------------
        v_l_old = F.mse_loss(ac.value_f(obs), ret)
        for _ in range(train_v_iters):
            v_loss = F.mse_loss(ac.value_f(obs), ret)
            train_v.zero_grad()
            v_loss.backward()
            average_gradients(train_v.param_groups)
            train_v.step()

        # --- Discriminator (only every train_dc_interv epochs) ------------
        if (e + 1) % train_dc_interv == 0:
            print('Discriminator Update!')
            con, s_diff = [torch.Tensor(x) for x in buffer.retrieve_dc_buff()]
            _, logp_dc, _ = disc(s_diff, con)
            d_l_old = -logp_dc.mean()

            for _ in range(train_dc_iters):
                _, logp_dc, _ = disc(s_diff, con)
                d_loss = -logp_dc.mean()
                train_dc.zero_grad()
                d_loss.backward()
                average_gradients(train_dc.param_groups)
                train_dc.step()

            _, logp_dc, _ = disc(s_diff, con)
            dc_l_new = -logp_dc.mean()
        else:
            d_l_old = 0
            dc_l_new = 0

        # --- Diagnostics --------------------------------------------------
        _, logp, _, v = ac(obs, act)
        pi_l_new = -(logp * (k * adv + pos)).mean()
        v_l_new = F.mse_loss(v, ret)
        kl = (logp_old - logp).mean()
        logger.store(LossPi=pi_loss, LossV=v_l_old, KL=kl, Entropy=entropy,
                     DeltaLossPi=(pi_l_new - pi_loss),
                     DeltaLossV=(v_l_new - v_l_old),
                     LossDC=d_l_old,
                     DeltaLossDC=(dc_l_new - d_l_old))
Пример #15
0
def main(args, local_rank):
    """Train a generation model on one GPU of a (possibly multi-process) job.

    Builds the vocabularies, constructs the model selected by ``args.arch``
    ('vanilla', 'mem' or 'rg'), then runs the training loop with gradient
    accumulation, periodic dev-set evaluation, and checkpointing of the best
    model by mean dev BLEU.

    Args:
        args: parsed command-line namespace (vocab paths, model sizes,
            optimization/scheduling hyper-parameters, data paths, ...).
        local_rank: index of the CUDA device this process drives; also used
            as the data-shard rank for the training DataLoader.
    """
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)

    vocabs = dict()
    vocabs['src'] = Vocab(args.src_vocab, 0, [BOS, EOS])
    vocabs['tgt'] = Vocab(args.tgt_vocab, 0, [BOS, EOS])

    # Only rank 0 reports the configuration, to avoid duplicated log lines
    # in multi-process runs.
    if args.world_size == 1 or (dist.get_rank() == 0):
        logger.info(args)
        for name in vocabs:
            logger.info("vocab %s, size %d, coverage %.3f", name,
                        vocabs[name].size, vocabs[name].coverage)

    set_seed(19940117)

    #device = torch.device('cpu')
    torch.cuda.set_device(local_rank)
    device = torch.device('cuda', local_rank)

    if args.arch == 'vanilla':
        model = Generator(vocabs, args.embed_dim, args.ff_embed_dim,
                          args.num_heads, args.dropout, args.enc_layers,
                          args.dec_layers, args.label_smoothing)
    elif args.arch == 'mem':
        model = MemGenerator(vocabs, args.embed_dim, args.ff_embed_dim,
                             args.num_heads, args.dropout, args.mem_dropout,
                             args.enc_layers, args.dec_layers,
                             args.mem_enc_layers, args.label_smoothing,
                             args.use_mem_score)
    elif args.arch == 'rg':
        logger.info("start building model")
        logger.info("building retriever")
        retriever = Retriever.from_pretrained(
            args.num_retriever_heads,
            vocabs,
            args.retriever,
            args.nprobe,
            args.topk,
            local_rank,
            use_response_encoder=(args.rebuild_every > 0))

        logger.info("building retriever + generator")
        model = RetrieverGenerator(vocabs, retriever, args.share_encoder,
                                   args.embed_dim, args.ff_embed_dim,
                                   args.num_heads, args.dropout,
                                   args.mem_dropout, args.enc_layers,
                                   args.dec_layers, args.mem_enc_layers,
                                   args.label_smoothing)

    # BUGFIX: global_step used to be initialized only on the non-resume
    # path, so running with --resume_ckpt crashed with NameError at the
    # training loop below. The checkpoint stores no step counter (only
    # 'args' and 'model'), so start from 0 in both cases.
    global_step = 0
    if args.resume_ckpt:
        model.load_state_dict(torch.load(args.resume_ckpt)['model'])

    # Give every worker a distinct seed so shuffling/dropout differ by rank.
    if args.world_size > 1:
        set_seed(19940117 + dist.get_rank())

    model = model.to(device)

    # The retriever parameters get a 10x smaller learning rate than the
    # rest of the model.
    retriever_params = [
        v for k, v in model.named_parameters() if k.startswith('retriever.')
    ]
    other_params = [
        v for k, v in model.named_parameters()
        if not k.startswith('retriever.')
    ]

    optimizer = Adam([{
        'params': retriever_params,
        'lr': args.embed_dim**-0.5 * 0.1
    }, {
        'params': other_params,
        'lr': args.embed_dim**-0.5
    }],
                     betas=(0.9, 0.98),
                     eps=1e-9)
    lr_schedule = get_inverse_sqrt_schedule_with_warmup(
        optimizer, args.warmup_steps, args.total_train_steps)
    train_data = DataLoader(vocabs,
                            args.train_data,
                            args.per_gpu_train_batch_size,
                            for_train=True,
                            rank=local_rank,
                            num_replica=args.world_size)

    model.eval()
    #dev_data = DataLoader(vocabs, cur_dev_data, args.dev_batch_size, for_train=False)
    #bleu = validate(device, model, dev_data, beam_size=5, alpha=0.6, max_time_step=10)

    # step counts micro-batches; global_step counts optimizer updates.
    step, epoch = 0, 0
    tr_stat = Statistics()
    logger.info("start training")
    model.train()

    best_dev_bleu = 0.
    while global_step <= args.total_train_steps:
        for batch in train_data:
            #step_start = time.time()
            batch = move_to_device(batch, device)
            if args.arch == 'rg':
                # Keep the retriever's memory bias frozen until the
                # generator has trained past update_retriever_after steps.
                loss, acc = model(
                    batch,
                    update_mem_bias=(global_step >
                                     args.update_retriever_after))
            else:
                loss, acc = model(batch)

            tr_stat.update({
                'loss': loss.item() * batch['tgt_num_tokens'],
                'tokens': batch['tgt_num_tokens'],
                'acc': acc
            })
            tr_stat.step()
            loss.backward()
            #step_cost = time.time() - step_start
            #print ('step_cost', step_cost)
            step += 1
            # Gradient accumulation: only update once every
            # gradient_accumulation_steps micro-batches
            # (-1 % k == k - 1, i.e. fire on the last one of each group).
            if not (step % args.gradient_accumulation_steps
                    == -1 % args.gradient_accumulation_steps):
                continue

            if args.world_size > 1:
                average_gradients(model)

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_schedule.step()
            optimizer.zero_grad()
            global_step += 1

            # Logging, evaluation and checkpointing happen on rank 0 only.
            if args.world_size == 1 or (dist.get_rank() == 0):
                if global_step % args.print_every == -1 % args.print_every:
                    logger.info("epoch %d, step %d, loss %.3f, acc %.3f",
                                epoch, global_step,
                                tr_stat['loss'] / tr_stat['tokens'],
                                tr_stat['acc'] / tr_stat['tokens'])
                    tr_stat = Statistics()
                if global_step % args.eval_every == -1 % args.eval_every:
                    model.eval()
                    # Use a tiny decode budget early in training to keep
                    # evaluation cheap before the model is useful.
                    max_time_step = 256 if global_step > 2 * args.warmup_steps else 5
                    bleus = []
                    for cur_dev_data in args.dev_data:
                        dev_data = DataLoader(vocabs,
                                              cur_dev_data,
                                              args.dev_batch_size,
                                              for_train=False)
                        bleu = validate(device,
                                        model,
                                        dev_data,
                                        beam_size=5,
                                        alpha=0.6,
                                        max_time_step=max_time_step)
                        bleus.append(bleu)
                    bleu = sum(bleus) / len(bleus)
                    logger.info("epoch %d, step %d, dev bleu %.2f", epoch,
                                global_step, bleu)
                    if bleu > best_dev_bleu:
                        # New best dev BLEU: also measure test BLEU and
                        # write checkpoint(s).
                        testbleus = []
                        for cur_test_data in args.test_data:
                            test_data = DataLoader(vocabs,
                                                   cur_test_data,
                                                   args.dev_batch_size,
                                                   for_train=False)
                            testbleu = validate(device,
                                                model,
                                                test_data,
                                                beam_size=5,
                                                alpha=0.6,
                                                max_time_step=max_time_step)
                            testbleus.append(testbleu)
                        testbleu = sum(testbleus) / len(testbleus)
                        logger.info("epoch %d, step %d, test bleu %.2f", epoch,
                                    global_step, testbleu)
                        torch.save({
                            'args': args,
                            'model': model.state_dict()
                        }, '%s/best.pt' % (args.ckpt, ))
                        if not args.only_save_best:
                            torch.save(
                                {
                                    'args': args,
                                    'model': model.state_dict()
                                },
                                '%s/epoch%d_batch%d_devbleu%.2f_testbleu%.2f' %
                                (args.ckpt, epoch, global_step, bleu,
                                 testbleu))
                        best_dev_bleu = bleu
                    model.train()

            # Periodically rebuild the retrieval index: rank 0 rebuilds it
            # on disk, the other ranks wait at the barrier, then every rank
            # reloads the fresh index.
            if args.rebuild_every > 0 and (global_step % args.rebuild_every
                                           == -1 % args.rebuild_every):
                model.retriever.drop_index()
                torch.cuda.empty_cache()
                next_index_dir = '%s/batch%d' % (args.ckpt, global_step)
                if args.world_size == 1 or (dist.get_rank() == 0):
                    model.retriever.rebuild_index(next_index_dir)
                    dist.barrier()
                else:
                    dist.barrier()
                model.retriever.update_index(next_index_dir, args.nprobe)

            if global_step > args.total_train_steps:
                break
        epoch += 1
    logger.info('rank %d, finish training after %d steps', local_rank,
                global_step)
Example #16
0
    def build(self):
        """Build the TF1 training graph: inputs, losses, metrics, train op.

        Sets up a piecewise-constant learning-rate schedule (scaled relative
        to a reference batch size of 512), the TFRecord input pipeline, and
        either a single-GPU graph or a multi-GPU tower graph whose per-tower
        gradients are averaged before being applied.

        Defines (among others): ``self.lr``, ``self.train_images``,
        ``self.train_labels``, ``self.embds``, ``self.logits``,
        ``self.train_loss``, ``self.train_acc``, ``self.train_op`` and
        ``self.train_summary``.
        """
        self.train_phase_dropout = tf.placeholder(dtype=tf.bool, shape=None, name='train_phase_dropout')
        self.train_phase_bn = tf.placeholder(dtype=tf.bool, shape=None, name='train_phase_bn')
        self.global_step = tf.Variable(name='global_step', initial_value=0, trainable=False)
        self.inc_op = tf.assign_add(self.global_step, 1, name='increment_global_step')
        # Linear-scaling rule: stretch the step boundaries and shrink the
        # LR values proportionally to how much smaller the batch is than 512.
        # NOTE(review): batch_size > 512 makes scale == 0 and divides by
        # zero below — confirm intended batch sizes are <= 512.
        scale = int(512.0/self.batch_size)
        lr_steps = [scale*s for s in self.config['lr_steps']]
        lr_values = [v/scale for v in self.config['lr_values']]
        # lr_steps = self.config['lr_steps']
        self.lr = tf.train.piecewise_constant(self.global_step, boundaries=lr_steps, values=lr_values, name='lr_schedule')

        cid = ClassificationImageData(img_size=self.image_size, augment_flag=self.config['augment_flag'], augment_margin=self.config['augment_margin'])
        train_dataset = cid.read_TFRecord(self.config['train_data']).shuffle(10000).repeat().batch(self.batch_size)
        train_iterator = train_dataset.make_one_shot_iterator()
        self.train_images, self.train_labels = train_iterator.get_next()
        self.train_images = tf.identity(self.train_images, 'input_images')
        self.train_labels = tf.identity(self.train_labels, 'labels')
        if self.gpu_num <= 1:
            # Single-GPU (or CPU) path: one forward pass, one optimizer.
            self.embds, self.logits, self.end_points = inference(self.train_images, self.train_labels, self.train_phase_dropout, self.train_phase_bn, self.config)
            self.embds = tf.identity(self.embds, 'embeddings')
            self.inference_loss = slim.losses.sparse_softmax_cross_entropy(logits=self.logits, labels=self.train_labels)
            self.wd_loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
            self.train_loss = self.inference_loss+self.wd_loss
            # FIX: tf.arg_max(..., dimension=...) is deprecated (and removed
            # in later TF releases); use tf.argmax(..., axis=...) instead.
            pred = tf.argmax(tf.nn.softmax(self.logits), axis=-1, output_type=tf.int64)
            self.train_acc = tf.reduce_mean(tf.cast(tf.equal(pred, self.train_labels), tf.float32))
            # Run BN moving-average updates together with the train step.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self.train_op = tf.train.MomentumOptimizer(learning_rate=self.lr, momentum=self.config['momentum']).minimize(self.train_loss)
        else:
            # Multi-GPU path: split the batch into one tower per GPU,
            # compute per-tower losses/gradients with shared variables,
            # then average gradients and apply them once.
            self.embds = []
            self.logits = []
            self.inference_loss = []
            self.wd_loss = []
            self.train_loss = []
            pred = []
            tower_grads = []
            update_ops = []
            opt = tf.train.MomentumOptimizer(learning_rate=self.lr, momentum=self.config['momentum'])
            train_images = tf.split(self.train_images, self.gpu_num)
            train_labels = tf.split(self.train_labels, self.gpu_num)
            for i in range(self.gpu_num):
                sub_train_images = train_images[i]
                sub_train_labels = train_labels[i]
                with tf.device('/gpu:%d' % i):
                    # reuse=(i > 0) shares the model variables across towers.
                    with tf.variable_scope(tf.get_variable_scope(), reuse=(i > 0)):
                        embds, logits, end_points = inference(sub_train_images, sub_train_labels, self.train_phase_dropout, self.train_phase_bn, self.config)
                        inference_loss = slim.losses.sparse_softmax_cross_entropy(logits=logits, labels=sub_train_labels)
                        wd_loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
                        train_loss = inference_loss+wd_loss
                        # FIX: deprecated tf.arg_max replaced by tf.argmax.
                        pred.append(tf.argmax(tf.nn.softmax(logits), axis=-1, output_type=tf.int64))
                        tower_grads.append(opt.compute_gradients(train_loss))
                        update_ops.append(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
                        self.embds.append(embds)
                        self.logits.append(logits)
                        self.inference_loss.append(inference_loss)
                        self.wd_loss.append(wd_loss)
                        self.train_loss.append(train_loss)
            self.embds = tf.concat(self.embds, axis=0)
            self.logits = tf.concat(self.logits, axis=0)
            self.inference_loss = tf.add_n(self.inference_loss)/self.gpu_num
            self.wd_loss = tf.add_n(self.wd_loss)/self.gpu_num
            self.train_loss = tf.add_n(self.train_loss)/self.gpu_num
            pred = tf.concat(pred, axis=0)
            self.train_acc = tf.reduce_mean(tf.cast(tf.equal(pred, self.train_labels), tf.float32))
            train_ops = [opt.apply_gradients(average_gradients(tower_grads))]
            train_ops.extend(update_ops)
            self.train_op = tf.group(*train_ops)


        self.train_summary = tf.summary.merge([
            tf.summary.scalar('inference_loss', self.inference_loss),
            tf.summary.scalar('wd_loss', self.wd_loss),
            tf.summary.scalar('train_loss', self.train_loss),
            tf.summary.scalar('train_acc', self.train_acc)
        ])
    def __init__(self,
                 iterator,
                 session,
                 model,
                 num_classes,
                 optimizer,
                 dataset,
                 p_norm=2.,
                 alpha=None,
                 decomp_type='bior2.2',
                 NUMPY_images=None,
                 NUMPY_labels=None,
                 learning_rate=.001,
                 weight_decay_p=.0001,
                 lp_wavelet_p=.0001,
                 batch_size=32,
                 bn_momentum=.99,
                 robust_regularization=True,
                 use_wavelet_decomposition=True,
                 wavelet_weights=[0, 1],
                 sensitivity_mode='logits',
                 graph=tf.get_default_graph()):

        self.iterator = iterator
        self.session = session
        self.model = model
        self.num_classes = num_classes
        self.optimizer = optimizer
        self.dataset = dataset
        self.robust_regularization = robust_regularization
        self.wavelet_weights = wavelet_weights
        self.nested_wavelet_weights = utils.nested_weight_list(wavelet_weights)
        self.sensitivity_mode = sensitivity_mode
        self.graph = graph
        self.decomp_type = decomp_type

        self.decomp_depth = len(wavelet_weights) - 1
        self.learning_rate = learning_rate
        self.weight_decay_p = weight_decay_p
        self.lp_wavelet_p = lp_wavelet_p
        self.batch_size = batch_size
        self.bn_momentum = bn_momentum
        self.graph = tf.get_default_graph()
        self.p_norm = p_norm

        self.alpha = alpha
        self.NUMPY_images = NUMPY_images
        self.NUMPY_labels = NUMPY_labels

        if use_wavelet_decomposition:
            from fwt import multi_channel_fwt, create_filter_bank
            self.decomp_filters, self.reconst_filters = create_filter_bank(
                decomp_type)

        devices = device_lib.list_local_devices()
        GPU_devices = [dev.name for dev in devices if dev.device_type == 'GPU']
        self.num_GPUs = len(GPU_devices)

        tensors = []
        scalars = []
        gradients = []
        summaries = []
        with tf.variable_scope(tf.get_variable_scope()):
            with session.as_default():
                for dev in range(self.num_GPUs):
                    with tf.device('/device:GPU:%d' % dev):
                        with tf.name_scope('GPU_%d' % dev) as scope:
                            print("Compiling on GPU %d ..." % dev)

                            tensors.append(dict())
                            scalars.append(dict())

                            # scalars finished converting to dict:
                            # mean_NLL, sum_of_true_logits, mean_correlations

                            # Get the inputs from the iterators
                            next_element = iterator.get_next()
                            tensors[-1]['images'] = next_element[0]
                            tensors[-1]['targets'] = next_element[1]
                            tensors[-1]['one_hot_targets'] = tf.one_hot(
                                tensors[-1]['targets'], self.num_classes)

                            # Get the forward propagated output
                            # for the current batch of this GPU.
                            network_output = model(tensors[-1]['images'])
                            tensors[-1]['logits'] = network_output

                            # For neural networks that use batch
                            # normalization, network_output is actually
                            # a list of tensors, where logits[1:]
                            # represent the inputs to the BatchNorm
                            # layers. Here, we handle this situation
                            # if it arises.
                            if type(network_output) == list:
                                tensors[-1]['logits'] = network_output[0]
                                bn_inputs = network_output[1:]
                                utils.add_bn_ops(model,
                                                 bn_inputs,
                                                 bn_momentum=bn_momentum)

                            tensors[-1]['predictions'] = tf.argmax(
                                tensors[-1]['logits'], axis=1)
                            tensors[-1][
                                'predicted_one_hot_targets'] = tf.one_hot(
                                    tensors[-1]['predictions'],
                                    self.num_classes)
                            tensors[-1]['predicted_logits'] = tf.reduce_max(
                                tensors[-1]['logits'], axis=1)
                            tensors[-1]['probabilities'] = tf.nn.softmax(
                                tensors[-1]['logits'])

                            #### x-terms, b-terms ####################

                            tensors[-1]['x_terms'] = Rop(
                                tensors[-1]['logits'], tensors[-1]['images'],
                                tensors[-1]['images'])
                            tensors[-1]['b_terms'] = tensors[-1][
                                'logits'] - tensors[-1]['x_terms']
                            tensors[-1]['predicted_b_terms'] = utils.select(
                                tensors[-1]['b_terms'],
                                tensors[-1]['predictions'], self.num_classes)

                            if self.alpha is not None:
                                tensors[-1]['taus'] = tensors[-1][
                                    'logits'] - self.alpha * tensors[-1][
                                        'x_terms']

                            #NUMPY SECTION
                            if NUMPY_images is not None and NUMPY_labels is not None:
                                NUMPY_network_output = model(NUMPY_images)
                                tensors[-1][
                                    'NUMPY_logits'] = NUMPY_network_output
                                if type(NUMPY_network_output) == list:
                                    tensors[-1][
                                        'NUMPY_logits'] = NUMPY_network_output[
                                            0]
                                tensors[-1]['NUMPY_predictions'] = tf.argmax(
                                    tensors[-1]['NUMPY_logits'], axis=1)

                                tensors[-1]['NUMPY_x_terms'] = Rop(
                                    tensors[-1]['NUMPY_logits'], NUMPY_images,
                                    NUMPY_images)
                                tensors[-1]['NUMPY_b_terms'] = tensors[-1][
                                    'NUMPY_logits'] - tensors[-1][
                                        'NUMPY_x_terms']

                                tensors[-1][
                                    'NUMPY_selected_x_terms'] = utils.select(
                                        tensors[-1]['NUMPY_x_terms'],
                                        NUMPY_labels, self.num_classes)
                                tensors[-1][
                                    'NUMPY_selected_b_terms'] = utils.select(
                                        tensors[-1]['NUMPY_b_terms'],
                                        NUMPY_labels, self.num_classes)

                                if self.alpha is not None:
                                    NUMPY_taus = tensors[-1][
                                        'NUMPY_logits'] - self.alpha * tensors[
                                            -1]['NUMPY_x_terms']

                                tensors[-1][
                                    'NUMPY_selected_logits'] = utils.select(
                                        tensors[-1]['NUMPY_logits'],
                                        NUMPY_labels, self.num_classes)

                                tensors[-1][
                                    'NUMPY_logit_sensitivities'] = tf.gradients(
                                        tf.reduce_sum(
                                            tensors[-1]
                                            ['NUMPY_selected_logits']),
                                        NUMPY_images)[0]
                                tensors[-1][
                                    'NUMPY_bias_shifted_images'] = bias_shifted_input(
                                        NUMPY_images,
                                        tensors[-1]['NUMPY_selected_b_terms'],
                                        tensors[-1]
                                        ['NUMPY_logit_sensitivities'])

                            ##########################################

                            # Classification loss
                            tensors[-1][
                                'NLLs'] = tf.nn.softmax_cross_entropy_with_logits_v2(
                                    labels=tensors[-1]['one_hot_targets'],
                                    logits=tensors[-1]['logits'])
                            scalars[-1]['mean_NLL'] = tf.reduce_mean(
                                tensors[-1]['NLLs'])

                            # Setting up the sensitivity penalty.
                            if sensitivity_mode == 'logits':
                                scalars[-1][
                                    'sum_of_true_logits'] = tf.reduce_sum(
                                        tensors[-1]['logits'] *
                                        tensors[-1]['one_hot_targets'])
                                tensors[-1]['sensitivities'] = tf.gradients(
                                    scalars[-1]['sum_of_true_logits'],
                                    tensors[-1]['images'],
                                    name='input_gradients')[0]
                            elif sensitivity_mode == 'NLL':
                                tensors[-1]['sensitivities'] = tf.gradients(
                                    scalars[-1]['mean_NLL'],
                                    tensors[-1]['images'],
                                    name='input_gradients')[0]

                            if use_wavelet_decomposition:
                                sensitivity_w_decomp = multi_channel_fwt(
                                    tensors[-1]['sensitivities'],
                                    self.decomp_filters,
                                    self.decomp_depth,
                                    output_type='list')

                            tensors[-1]['inner_products'] = tf.reduce_sum(
                                tensors[-1]['images'] *
                                tensors[-1]['sensitivities'],
                                axis=[1, 2, 3])

                            tensors[-1]['sensitivity_norms'] = tf.sqrt(
                                tf.reduce_sum(tensors[-1]['sensitivities']**2,
                                              axis=[1, 2, 3],
                                              name='sens_norm'))
                            tensors[-1]['image_norms'] = tf.sqrt(
                                tf.reduce_sum(tensors[-1]['images']**2,
                                              axis=[1, 2, 3],
                                              name='im_norm'))

                            tensors[-1]['norm_products'] = tensors[-1][
                                'sensitivity_norms'] * tensors[-1][
                                    'image_norms']

                            epsilon = 0.0
                            tensors[-1]['correlations'] = tensors[-1][
                                'inner_products'] / (
                                    tensors[-1]['norm_products'] + epsilon)

                            scalars[-1]['mean_correlation'] = tf.reduce_mean(
                                tensors[-1]['correlations'])
                            scalars[-1]['mean_inner_product'] = tf.reduce_mean(
                                tensors[-1]['inner_products'])
                            scalars[-1]['mean_norm_product'] = tf.reduce_mean(
                                tensors[-1]['norm_products'])

                            tensors[-1]['true_logits'] = tf.reduce_sum(
                                tensors[-1]['logits'] *
                                tensors[-1]['one_hot_targets'],
                                axis=1)
                            scalars[-1]['sum_of_true_logits'] = tf.reduce_sum(
                                tensors[-1]['true_logits'])
                            tensors[-1]['logit_sensitivities'] = tf.gradients(
                                scalars[-1]['sum_of_true_logits'],
                                tensors[-1]['images'],
                                name='logit_input_gradients')[0]

                            tensors[-1][
                                'logit_inner_products'] = tf.reduce_sum(
                                    tensors[-1]['images'] *
                                    tensors[-1]['logit_sensitivities'],
                                    axis=[1, 2, 3])

                            tensors[-1]['logit_sensitivity_norms'] = tf.sqrt(
                                tf.reduce_sum(
                                    tensors[-1]['logit_sensitivities']**2,
                                    axis=[1, 2, 3],
                                    name='sens_norm'))

                            tensors[-1]['logit_norm_products'] = tensors[-1][
                                'logit_sensitivity_norms'] * tensors[-1][
                                    'image_norms']

                            tensors[-1]['logit_correlations'] = tensors[-1]['logit_inner_products'] / \
                                (tensors[-1]['logit_norm_products'] + epsilon)

                            scalars[-1][
                                'mean_logit_correlation'] = tf.reduce_mean(
                                    tensors[-1]['logit_correlations'])
                            scalars[-1][
                                'mean_logit_inner_product'] = tf.reduce_mean(
                                    tensors[-1]['logit_inner_products'])
                            scalars[-1][
                                'mean_logit_norm_product'] = tf.reduce_mean(
                                    tensors[-1]['logit_norm_products'])

                            # Again as a tiled image, for visualization.
                            # Only do this if the dimensions work out.
                            tiled_image_works = False
                            if use_wavelet_decomposition:
                                try:
                                    tensors[-1][
                                        'sensitivity_w_decomp_imgs'] = multi_channel_fwt(
                                            tensors[-1]['sensitivities'],
                                            self.decomp_filters,
                                            self.decomp_depth,
                                            output_type='image')
                                    tiled_image_works = True
                                except tf.errors.OpError:
                                    print(
                                        "Creating a tiled wavelet image failed."
                                    )

                            # sum up all the p-norms of the FWTs of
                            # all channels.
                            if use_wavelet_decomposition:
                                sensitivity_w_mean_lp = 0
                                for decomp in sensitivity_w_decomp:
                                    sensitivity_w_mean_lp += utils.lp_norm_weighted(
                                        decomp,
                                        self.nested_wavelet_weights,
                                        p_norm=self.p_norm)
                            else:
                                # Otherwise, just calculate the p-norm of the
                                # sensitivity.
                                sensitivity_w_mean_lp = utils.lp_norm(
                                    tensors[-1]['sensitivities'],
                                    p_norm=self.p_norm)

                            scalars[-1][
                                'sensitivity_w_mean_lp'] = sensitivity_w_mean_lp

                            ############ ONLY FOR LOGGING PURPOSES ###################
                            tensors[-1]['random_targets'] = tf.random_uniform(
                                tf.shape(tensors[-1]['targets']),
                                maxval=self.num_classes - 1,
                                dtype=tf.int32)

                            tensors[-1]['random_one_hot_targets'] = tf.one_hot(
                                tensors[-1]['random_targets'],
                                self.num_classes)
                            tensors[-1]['random_logits'] = tf.reduce_sum(
                                tensors[-1]['logits'] *
                                tensors[-1]['random_one_hot_targets'],
                                axis=1)
                            scalars[-1][
                                'sum_of_random_logits'] = tf.reduce_sum(
                                    tensors[-1]['random_logits'])

                            tensors[-1][
                                'random_logit_sensitivities'] = tf.gradients(
                                    scalars[-1]['sum_of_random_logits'],
                                    tensors[-1]['images'],
                                    name='random_logit_sensitivities')[0]
                            tensors[-1][
                                'random_logit_inner_products'] = tf.reduce_sum(
                                    tensors[-1]['images'] *
                                    tensors[-1]['random_logit_sensitivities'],
                                    axis=[1, 2, 3])
                            tensors[-1][
                                'random_logit_sensitivity_norms'] = tf.sqrt(
                                    tf.reduce_sum(
                                        tensors[-1]
                                        ['random_logit_sensitivities']**2,
                                        axis=[1, 2, 3]))

                            scalars[-1][
                                'sum_of_predicted_logits'] = tf.reduce_sum(
                                    tensors[-1]['predicted_logits'])
                            tensors[-1][
                                'predicted_logit_sensitivities'] = tf.gradients(
                                    scalars[-1]['sum_of_predicted_logits'],
                                    tensors[-1]['images'],
                                    name='predicted_logit_sensitivities')[0]
                            tensors[-1][
                                'predicted_logit_inner_products'] = tf.reduce_sum(
                                    tensors[-1]['images'] * tensors[-1]
                                    ['predicted_logit_sensitivities'],
                                    axis=[1, 2, 3])
                            tensors[-1][
                                'predicted_logit_sensitivity_norms'] = tf.sqrt(
                                    tf.reduce_sum(
                                        tensors[-1]
                                        ['predicted_logit_sensitivities']**2,
                                        axis=[1, 2, 3]))

                            tensors[-1]['true_logit_sensitivities'] = tensors[
                                -1]['logit_sensitivities']
                            tensors[-1][
                                'true_logit_inner_products'] = tf.reduce_sum(
                                    tensors[-1]['images'] *
                                    tensors[-1]['true_logit_sensitivities'],
                                    axis=[1, 2, 3])
                            tensors[-1][
                                'true_logit_sensitivity_norms'] = tf.sqrt(
                                    tf.reduce_sum(
                                        tensors[-1]['true_logit_sensitivities']
                                        **2,
                                        axis=[1, 2, 3]))

                            # Calculate the bias gradients
                            flatten = lambda a: tf.reshape(a, (-1, ))
                            IP = lambda a, b: tf.reduce_sum(a * b)

                            biases = [
                                b for b in model.trainable_weights
                                if 'bias' in b.name
                            ]
                            biases += tf.get_collection('bn_betas')
                            biases += tf.get_collection('bn_means')

                            random_bias_gradients = tf.gradients(
                                scalars[-1]['sum_of_random_logits'],
                                biases,
                                name='random_bias_gradients')

                            random_bg = [
                                IP(flatten(b), flatten(g))
                                for (b,
                                     g) in zip(biases, random_bias_gradients)
                            ]
                            random_bias_inner_products = tf.accumulate_n(
                                random_bg)

                            predicted_bias_gradients = tf.gradients(
                                scalars[-1]['sum_of_predicted_logits'],
                                biases,
                                name='predicted_bias_gradients')
                            predicted_bg = [
                                IP(flatten(b), flatten(g)) for (
                                    b,
                                    g) in zip(biases, predicted_bias_gradients)
                            ]
                            predicted_bias_inner_products = tf.accumulate_n(
                                predicted_bg)

                            true_bias_gradients = tf.gradients(
                                scalars[-1]['sum_of_true_logits'],
                                biases,
                                name='true_bias_gradients')

                            true_bg = [
                                IP(flatten(b), flatten(g))
                                for (b, g) in zip(biases, true_bias_gradients)
                            ]
                            true_bias_inner_products = tf.add_n(true_bg)

                            zero_image = tf.zeros_like(tensors[-1]['images'])
                            tensors[-1]['zero_output'] = model(zero_image)[0]

                            tensors[-1]['random_zero_logits'] = tf.reduce_sum(
                                tensors[-1]['zero_output'] *
                                tensors[-1]['random_one_hot_targets'],
                                axis=1)
                            tensors[-1][
                                'predicted_zero_logits'] = tf.reduce_sum(
                                    tensors[-1]['zero_output'] *
                                    tensors[-1]['predicted_one_hot_targets'],
                                    axis=1)
                            tensors[-1]['true_zero_logits'] = tf.reduce_sum(
                                tensors[-1]['zero_output'] *
                                tensors[-1]['one_hot_targets'],
                                axis=1)

                            # Calculate the approximate random robustness

                            tensors[-1]['inner_product_differences'] = (
                                tensors[-1]['predicted_logit_inner_products'] -
                                tensors[-1]['random_logit_inner_products'])

                            tensors[-1][
                                'bias_differences'] = predicted_bias_inner_products - random_bias_inner_products

                            numerator = tensors[-1][
                                'inner_product_differences'] - tensors[-1][
                                    'bias_differences']

                            tensors[-1]['logit_sensitivity_differences'] = (
                                tensors[-1]['predicted_logit_sensitivities'] -
                                tensors[-1]['random_logit_sensitivities'])
                            denominator = tf.sqrt(
                                tf.reduce_sum(
                                    tensors[-1]
                                    ['logit_sensitivity_differences']**2))

                            tensors[-1][
                                'approximate_random_robustness'] = numerator / denominator
                            tensors[-1][
                                'inner_product_differences_normalized'] = (
                                    tensors[-1]['inner_product_differences'] /
                                    denominator)
                            tensors[-1][
                                'bias_differences_normalized'] = tensors[-1][
                                    'bias_differences'] / denominator

                            tensors[-1][
                                'bias_difference_shifted_images'] = bias_shifted_input(
                                    tensors[-1]['images'],
                                    tensors[-1]['bias_differences'],
                                    tensors[-1]
                                    ['logit_sensitivity_differences'])

                            #print(tensors[-1]['bias_differences_normalized'])
                            #crash()
                            #######################################################

                            # Collect the network's weights and set up
                            # the weight decay penalty
                            trainable_weights = model.trainable_weights
                            scalars[-1]['weight_norm'] = tf.add_n([
                                tf.reduce_sum(w**2) for w in trainable_weights
                            ])

                            # Assemble the total loss for this GPU
                            scalars[-1]['total_loss'] = scalars[-1]['mean_NLL']
                            scalars[-1][
                                'total_loss'] += weight_decay_p * scalars[-1][
                                    'weight_norm']
                            if robust_regularization:
                                scalars[-1][
                                    'sensitivity_penalty'] = lp_wavelet_p * scalars[
                                        -1]['sensitivity_w_mean_lp']
                                scalars[-1]['total_loss'] += scalars[-1][
                                    'sensitivity_penalty']

                            # Everything that is tracked during training
                            # goes here. Top-5 and top-1 accuracies are
                            # automatically added.
                            summary_dict = {
                                'total_loss':
                                scalars[-1]['total_loss'],
                                'mean_NLL':
                                scalars[-1]['mean_NLL'],
                                'weight_2_norm_squared':
                                scalars[-1]['weight_norm'],
                                'mean_sensitivity_wavelet_coeffs_lp':
                                scalars[-1]['sensitivity_w_mean_lp']
                            }

                            # Add some hyperparameters, too.
                            # Some redundant calculations through averaging
                            # later, but the computational overhead is negligible.
                            summary_dict['learning_rate_'] = learning_rate
                            summary_dict['correlation_'] = scalars[-1][
                                'mean_correlation']
                            summary_dict['inner_product_'] = scalars[-1][
                                'mean_inner_product']
                            summary_dict['norm_product_'] = scalars[-1][
                                'mean_norm_product']
                            summary_dict['logit_correlation_'] = scalars[-1][
                                'mean_logit_correlation']
                            summary_dict['logit_inner_product_'] = scalars[-1][
                                'mean_logit_inner_product']
                            summary_dict['logit_norm_product_'] = scalars[-1][
                                'mean_logit_norm_product']
                            summary_dict[
                                'weight_decay_parameter_'] = weight_decay_p
                            summary_dict[
                                'lp_Wavelet_parameter_'] = lp_wavelet_p
                            summary_dict[
                                'total_batch_size'] = batch_size * self.num_GPUs
                            summary_dict['bn_momentum_'] = bn_momentum
                            summary_dict['p_norm'] = p_norm

                            if robust_regularization:
                                summary_dict['sensitivity_penalty'] = scalars[
                                    -1]['sensitivity_penalty']

                            summary_dict = summary_utils.prepare_summaries(
                                summary_dict=summary_dict,
                                predictions=tensors[-1]['probabilities'],
                                labels=tensors[-1]['targets'])
                            summaries.append(summary_dict)

                            # Collect the gradients for every GPU
                            gradients.append(
                                optimizer.compute_gradients(
                                    scalars[-1]['total_loss'],
                                    var_list=trainable_weights,
                                    colocate_gradients_with_ops=True))

                            # So far, the adversarial attack model is only
                            # created on one GPU. Different parallelized versions
                            # always lead to errors.
                            if dev == 0:
                                self.adversarial_model = TensorFlowModel(
                                    tensors[-1]['images'],
                                    tensors[-1]['logits'],
                                    bounds=self.dataset.bounds)

        print("Done.")

        # Copy the lists 'tensors' and 'scalars' and replace these with an aggregated version:
        # Concatenate the tensors and average the scalars.
        self.tensors = dict()
        self.scalars = dict()
        for key in tensors[0].keys():
            print(key)
            self.tensors[key] = tf.concat(
                [tensors_item[key] for tensors_item in tensors], axis=0)
        for key in scalars[0].keys():
            self.scalars[key] = tf.reduce_mean(
                [scalars_item[key] for scalars_item in scalars])

        # Create self.GPU_collections for backwards compatibility
        self.GPU_collections = {**self.tensors, **self.scalars}
        self.GPU_collections['top_1'] = tf.concat(tf.get_collection('top_1'),
                                                  0)
        self.GPU_collections['top_5'] = tf.concat(tf.get_collection('top_5'),
                                                  0)

        # Collection and apply the gradients over all used
        # GPUs for synchronous parallel training.
        avg_grads = utils.average_gradients(gradients)
        gradient_application = optimizer.apply_gradients(avg_grads)
        # We combine the gradient update and possibly the
        # batch normalization update operators into one.
        self.train_op = tf.group(gradient_application,
                                 *(tf.get_collection('bn_update_ops')))

        summary_dict = summary_utils.collect_summaries(summaries)
        self.summary_op = summary_utils.create_summary_op(summary_dict)

        if use_wavelet_decomposition:
            wavelet_summary = tf.summary.tensor_summary(
                'wavelet_weights', self.wavelet_weights)
            self.summary_op = tf.summary.merge(
                [self.summary_op, wavelet_summary])

        # Here, we create a tiled image summary for Tensorboard.
        # We hereby shift the range of the sensitivity and
        # possibly its decomposition to the range of the image.
        image_range = self.dataset.image_range()
        image_max = image_range[1]
        image_min = image_range[0]
        image_span = image_max - image_min
        image_mid = image_span / 2.

        self.images = self.dataset.interpret_as_image(
            self.GPU_collections['images'])
        self.saliencies = self.GPU_collections['sensitivities']
        saliencies_max = tf.reduce_max(tf.abs(self.saliencies), [1, 2],
                                       keepdims=True)
        normalized_saliencies = image_span * self.saliencies / \
            (2*saliencies_max + 1e-9) + image_mid

        if use_wavelet_decomposition:
            self.saliency_decomps = self.GPU_collections[
                'sensitivity_w_decomp_imgs']
            saliency_decomps_max = tf.reduce_max(tf.abs(self.saliency_decomps),
                                                 [1, 2],
                                                 keepdims=True)
            normalized_decomps = image_span * self.saliency_decomps / \
                (2*saliency_decomps_max + 1e-9) + image_mid

        composite_image = [self.images, normalized_saliencies]

        if tiled_image_works:
            composite_image.append(normalized_decomps)

        img_saliency_decomp = tf.concat(composite_image, 2)

        self.img_summary_op = tf.summary.image('img_saliency_decomp',
                                               img_saliency_decomp,
                                               max_outputs=10)
Пример #18
0
        def training():
            """Build the multi-GPU training graph and run the optimization loop.

            The input batches are split into one shard per GPU; each tower
            computes its own loss and gradients, which are averaged for a
            single synchronous weight update. Progress is printed every 100
            steps and a checkpoint is written every ``args.save_after`` steps,
            plus once more after the final step.

            NOTE(review): relies on enclosing-scope names (``left_1_batch``,
            ``right_1_batch``, ``left_2_batch``, ``args``, ``params``,
            ``opt_step``, ``global_step``, ``learning_rate``, ``init_op``,
            ``num_total_steps``) -- confirm they are defined by the caller.
            """
            # Split the batch across GPU's -- one shard per tower.
            left_1_splits = tf.split(left_1_batch, args.num_gpus, 0)
            right_1_splits = tf.split(right_1_batch, args.num_gpus, 0)
            left_2_splits = tf.split(left_2_batch, args.num_gpus, 0)

            tower_grads = []
            tower_losses = []
            # The first tower creates the variables; later towers reuse them.
            reuse_variables = False

            with tf.variable_scope(tf.get_variable_scope()):
                for i in range(args.num_gpus):
                    with tf.device('/gpu:%d' % i):
                        model = MonoVODModel(params, args.mode,
                                             left_1_splits[i],
                                             right_1_splits[i],
                                             left_2_splits[i], reuse_variables,
                                             i)
                        loss = model.total_loss
                        tower_losses.append(loss)

                        reuse_variables = True

                        # Per-tower gradients; averaged below for the update.
                        grads = opt_step.compute_gradients(loss)
                        tower_grads.append(grads)

            grads = average_gradients(tower_grads)

            # BACKPROPAGATION
            apply_gradient_op = opt_step.apply_gradients(
                grads, global_step=global_step)

            total_loss = tf.reduce_mean(tower_losses)

            tf.summary.scalar('learning_rate', learning_rate, ['model_0'])
            tf.summary.scalar('total_loss', total_loss, ['model_0'])
            summary_op = tf.summary.merge_all('model_0')

            # SESSION
            config = tf.ConfigProto(allow_soft_placement=True)
            sess = tf.Session(config=config)

            # SAVER
            summary_writer = tf.summary.FileWriter(
                args.log_directory + '/' + args.model_name, sess.graph)
            train_saver = tf.train.Saver()

            # COUNT PARAMS
            total_num_parameters = 0
            for variable in tf.trainable_variables():
                total_num_parameters += np.array(
                    variable.get_shape().as_list()).prod()
            print("number of trainable parameters: {}".format(
                total_num_parameters))

            # INIT
            sess.run(tf.global_variables_initializer())
            sess.run(init_op)

            # LOAD CHECKPOINT IF SET
            if args.checkpoint_path != '':
                # Strip the file extension to get the checkpoint prefix.
                train_saver.restore(sess, args.checkpoint_path.split(".")[0])

                if args.retrain:
                    # Restart step counting when retraining from a checkpoint.
                    sess.run(global_step.assign(0))

            save_after_n_steps = args.save_after

            # GO!
            print('Training...')
            start_step = global_step.eval(session=sess)
            start_time = time.time()
            for step in range(start_step, num_total_steps):
                before_op_time = time.time()
                _, loss_value = sess.run([apply_gradient_op, total_loss])

                duration = time.time() - before_op_time
                if step and step % 100 == 0:
                    examples_per_sec = params.batch_size / duration
                    time_sofar = (time.time() - start_time) / 3600
                    training_time_left = (num_total_steps / step -
                                          1.0) * time_sofar
                    print_string = 'batch {:>6} | examples/s: {:4.2f} | loss: {:.5f} | time elapsed: {:.2f}h | time left: {:.2f}h'
                    print(
                        print_string.format(step, examples_per_sec, loss_value,
                                            time_sofar, training_time_left))
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, global_step=step)

                # Save
                if step and step % save_after_n_steps == 0:
                    train_saver.save(sess,
                                     args.log_directory + '/' +
                                     args.model_name + '/model',
                                     global_step=step)

            # Final checkpoint after the last training step.
            train_saver.save(sess,
                             args.log_directory + '/' + args.model_name +
                             '/model',
                             global_step=num_total_steps)
Пример #19
0
def main(args):
    """Train a multi-GPU Inception-v1 classifier with synchronous updates.

    Builds one tower per entry of ``args.gpu_device``, each consuming its own
    slice of the input batch, averages the tower gradients, and runs the
    training loop with periodic summary writing and checkpointing.

    NOTE(review): relies on module-level names (``inception_v1``,
    ``average_gradients``, ``load_ckpt``, ``save_ckpt``, ``one_element``) --
    confirm they are defined elsewhere in this file.
    """
    with tf.device("/cpu:0"):
        global_step = tf.Variable(0, trainable=False)
        images = tf.placeholder(tf.float32, shape=[None, args.image_height, args.image_width, 3], name="img_inputs")
        labels = tf.placeholder(tf.int64, shape=[None], name="img_labels")
        dropout_rate = tf.placeholder(tf.float32, name="dropout")

        # Define the learning-rate schedule.
        # NOTE(review): the decay boundaries are hard-coded instead of being
        # derived from args.lr_steps/batch size -- confirm this is intended.
        lr_steps = [5000, 10000]
        lr = tf.train.piecewise_constant(global_step, boundaries=lr_steps, values=[0.0001, 0.00005, 0.00003], name='lr_schedule')
        # Define the optimization method.
        optimizer = tf.train.AdamOptimizer(learning_rate=lr)

        tower_grads = []
        loss_dict = {}
        loss_keys = []
        reuse_vars = False

        # Loop over all GPUs and construct their own computation graph
        for i in range(len(args.gpu_device)):
            with tf.device("/gpu:%d" % args.gpu_device[i]):
                # Split data between GPUs: each tower gets its own batch slice.
                _x = images[i * args.batch_size: (i + 1) * args.batch_size]
                _y = labels[i * args.batch_size: (i + 1) * args.batch_size]

                # Because Dropout behaves differently at training and
                # prediction time, we create two distinct computation graphs
                # that share the same weights.

                # Create a graph for training.
                ax1_out_train, ax2_out_train, main_out_train = inception_v1(_x, args.dropout, is_training=True, reuse=reuse_vars)

                # Create another graph for testing that reuses the same weights.
                ax1_out_test, ax2_out_test, main_out_test = inception_v1(_x, args.dropout, is_training=False, reuse=True)

                # sparse_softmax_cross_entropy_with_logits expects `logits` of
                # shape [batch_size, num_classes] and `labels` as a 1-D vector
                # of length batch_size whose values are class ids in
                # [0, num_classes).
                ax1_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=ax1_out_train, labels=_y))
                ax2_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=ax2_out_train, labels=_y))
                main_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=main_out_train, labels=_y))
                # GoogLeNet-style total loss: auxiliary heads weighted by 0.3.
                loss = main_loss + 0.3 * ax1_loss + 0.3 * ax2_loss

                loss_dict[('ax1_loss_%s_%d' % ('gpu', i))] = ax1_loss
                loss_keys.append(('ax1_loss_%s_%d' % ('gpu', i)))
                loss_dict[('ax2_loss_%s_%d' % ('gpu', i))] = ax2_loss
                loss_keys.append(('ax2_loss_%s_%d' % ('gpu', i)))
                loss_dict[('main_loss_%s_%d' % ('gpu', i))] = main_loss
                loss_keys.append(('main_loss_%s_%d' % ('gpu', i)))

                grads = optimizer.compute_gradients(loss)
                tower_grads.append(grads)

                # Only the first GPU computes accuracy.
                if i == 0:
                    pred = tf.nn.softmax(main_out_test)
                    # Evaluate model (with test logits, for dropout to be disabled)
                    correct_prediction = tf.equal(tf.argmax(pred, -1), _y)
                    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

                reuse_vars = True

        tower_grads = average_gradients(tower_grads)
        # Apply the gradients to adjust the shared variables.
        # BUG FIX: apply the *averaged* gradients (tower_grads). The original
        # code passed `grads`, i.e. only the last tower's gradients, which
        # silently discarded the gradient averaging above.
        train_op = optimizer.apply_gradients(tower_grads, global_step=global_step)

        # Start training.
        config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            # summary writer
            summary = tf.summary.FileWriter(args.log_dir, sess.graph)
            summaries = []
            # Add gradient histograms (of the averaged gradients).
            for grad, var in tower_grads:
                if grad is not None:
                    summaries.append(tf.summary.histogram(var.op.name + '/gradients', grad))
            # Add trainable-variable histograms.
            for var in tf.trainable_variables():
                summaries.append(tf.summary.histogram(var.op.name, var))
            # Add loss summaries.
            for keys, val in loss_dict.items():
                summaries.append(tf.summary.scalar(keys, val))
            # Add learning rate (tag typo 'leraning_rate' fixed).
            summaries.append(tf.summary.scalar('learning_rate', lr))
            # Add train accuracy.
            summaries.append(tf.summary.scalar("train_acc", accuracy))
            summary_op = tf.summary.merge(summaries)

            init = tf.global_variables_initializer()
            sess.run(init)

            saver = tf.train.Saver()

            # Restore a checkpoint if one exists.
            load_ckpt(sess, saver, args.checkpoint_dir)

            # Begin iteration.
            for step in range(1, args.max_steps + 1):
                image_batch, label_batch = sess.run(one_element)
                feed_dict = {images: image_batch, labels: label_batch}
                start = time.time()
                _, summary_op_val, train_loss, train_acc = sess.run([train_op, summary_op, loss, accuracy], feed_dict=feed_dict)
                end = time.time()
                print("Current step: %d/%d Time: %.4f Train_loss: %.6f Train_acc: %.4f"%(step, args.max_steps, end - start, train_loss, train_acc))

                if step % args.summary_freq == 0:
                    summary.add_summary(summary_op_val, step)

                # Save a checkpoint every save_freq steps.
                if step % args.save_freq == 0:
                    save_ckpt(sess, saver, args.checkpoint_dir, global_step)
Пример #20
0
    def train_model(self, chpt_path=None):
        """Build the multi-tower training graph and run the training loop.

        One tower per GPU computes its loss/gradients via ``build_model``;
        the gradients are averaged for a synchronous update. Logging, summary
        writing and checkpointing happen at fixed fractions of
        ``self.num_train_steps``.

        Args:
            chpt_path: Optional checkpoint path handed to ``make_init_fn`` to
                (partially) initialize the network before training.
        """
        g = tf.Graph()
        with g.as_default():
            with tf.device('/cpu:0'):
                tf.compat.v1.random.set_random_seed(123)

                # Init global step
                self.global_step = tf.compat.v1.train.create_global_step()

                batch_queue = self.get_data_queue()
                opt = self.optimizer()

                # Calculate the gradients for each model tower.
                tower_grads = []
                loss = 0.

                with tf.compat.v1.variable_scope(
                        tf.compat.v1.get_variable_scope()):
                    for i in range(self.num_gpus):
                        with tf.device('/gpu:%d' % i):
                            with tf.name_scope('tower_{}'.format(i)) as scope:
                                loss_, grads_, layers_ = self.build_model(
                                    batch_queue, opt, scope, i)
                                # Average per-tower losses for reporting.
                                loss += loss_ / self.num_gpus

                            tower_grads.append(grads_)

                grad = average_gradients(tower_grads)

                # Make summaries (layers_ comes from the last tower built).
                self.make_summaries(grad, layers_)

                # Select the variables subject to weight decay; optionally
                # exclude batch-norm gamma/beta and biases.
                print(
                    '========================================WD VARS==============================================='
                )
                wd_vars = get_variables_to_train(self.train_scopes)
                if self.excl_gamma_wd:
                    wd_vars = [v for v in wd_vars if 'gamma' not in v.op.name]
                if self.excl_beta_wd:
                    wd_vars = [v for v in wd_vars if 'beta' not in v.op.name]
                if self.excl_bias_wd:
                    wd_vars = [v for v in wd_vars if 'biases' not in v.op.name]
                print('WD variables: {}'.format([v.op.name for v in wd_vars]))
                print(
                    '=============================================================================================='
                )

                # Apply the gradients to adjust the shared variables.
                # NOTE(review): decay_var_list suggests `opt` is a decoupled
                # weight-decay optimizer (AdamW-style) -- confirm in
                # self.optimizer().
                train_op = opt.apply_gradients(grad,
                                               global_step=self.global_step,
                                               decay_var_list=wd_vars)

                # Group all updates to into a single train op.
                train_op = control_flow_ops.with_dependencies([train_op], loss)

                # Create a saver.
                saver = tf.compat.v1.train.Saver(
                    tf.compat.v1.global_variables())
                init_fn = self.make_init_fn(chpt_path)

                # Build the summary operation from the last tower summaries.
                summary_op = tf.compat.v1.summary.merge(self.summaries)

                # Build an initialization operation to run below.
                init = tf.compat.v1.global_variables_initializer()

                # Start running operations on the Graph.
                sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(
                    allow_soft_placement=True, log_device_placement=False),
                                            graph=g)
                sess.run(init)
                if init_fn:
                    init_fn(sess)

                summary_writer = tf.compat.v1.summary.FileWriter(
                    self.get_save_dir(), sess.graph)

                # Periodic-action intervals, hoisted out of the loop.
                # BUG FIX: max(1, ...) guards against ZeroDivisionError when
                # num_train_steps < 2000 (resp. < 200, < 40).
                log_interval = max(1, self.num_train_steps // 2000)
                summary_interval = max(1, self.num_train_steps // 200)
                ckpt_interval = max(1, self.num_train_steps // 40)

                init_step = sess.run(self.global_step)
                print('Start training at step: {}'.format(init_step))
                for step in range(init_step, self.num_train_steps):

                    start_time = time.time()
                    _, loss_value = sess.run([train_op, loss])

                    duration = time.time() - start_time

                    assert not np.isnan(
                        loss_value), 'Model diverged with loss = NaN'

                    if step % log_interval == 0:
                        num_examples_per_step = self.batch_size
                        examples_per_sec = num_examples_per_step / duration
                        sec_per_batch = duration
                        print(
                            '{}: step {}/{}, loss = {} ({} examples/sec; {} sec/batch)'
                            .format(datetime.now(), step, self.num_train_steps,
                                    loss_value, examples_per_sec,
                                    sec_per_batch))
                        sys.stdout.flush()

                    if step % summary_interval == 0:
                        print('Writing summaries...')
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, step)

                    # Save the model checkpoint periodically.
                    if step % ckpt_interval == 0 or (
                            step + 1) == self.num_train_steps:
                        checkpoint_path = os.path.join(self.get_save_dir(),
                                                       'model.ckpt')
                        print(
                            'Saving checkpoint to: {}'.format(checkpoint_path))
                        saver.save(sess, checkpoint_path, global_step=step)
Пример #21
0
def main(args, local_rank):
    """Train the matching model, optionally across multiple distributed ranks.

    Args:
        args: Parsed command-line namespace (vocab paths, model sizes,
            optimization hyper-parameters, checkpointing options).
        local_rank: GPU index on this node; also used as the CUDA device id.
    """

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)

    vocabs = dict()
    vocabs['src'] = Vocab(args.src_vocab, 0, [BOS, EOS])
    vocabs['tgt'] = Vocab(args.tgt_vocab, 0, [BOS, EOS])

    # Only rank 0 logs the configuration, to avoid duplicated output.
    if args.world_size == 1 or (dist.get_rank() == 0):
        logger.info(args)
        for name in vocabs:
            logger.info("vocab %s, size %d, coverage %.3f", name,
                        vocabs[name].size, vocabs[name].coverage)

    set_seed(19940117)

    #device = torch.device('cpu')
    torch.cuda.set_device(local_rank)
    device = torch.device('cuda', local_rank)

    # Either resume from a checkpoint or build a fresh model from hyper-params.
    if args.resume_ckpt:
        model = MatchingModel.from_pretrained(vocabs, args.resume_ckpt)
    else:
        model = MatchingModel.from_params(vocabs, args.layers, args.embed_dim,
                                          args.ff_embed_dim, args.num_heads,
                                          args.dropout, args.output_dim,
                                          args.bow)

    # Re-seed per rank so data order / dropout differ across workers.
    if args.world_size > 1:
        set_seed(19940117 + dist.get_rank())

    model = model.to(device)

    if args.resume_ckpt:
        # Report the starting dev accuracy of the restored checkpoint.
        dev_data = DataLoader(vocabs,
                              args.dev_data,
                              args.dev_batch_size,
                              addition=args.additional_negs)
        acc = validate(model, dev_data, device)
        logger.info("initialize from %s, initial acc %.2f", args.resume_ckpt,
                    acc)

    optimizer = Adam(model.parameters(),
                     lr=args.lr,
                     betas=(0.9, 0.98),
                     eps=1e-9)
    lr_schedule = get_linear_schedule_with_warmup(optimizer, args.warmup_steps,
                                                  args.total_train_steps)
    train_data = DataLoader(vocabs,
                            args.train_data,
                            args.per_gpu_train_batch_size,
                            worddrop=args.worddrop,
                            addition=args.additional_negs)
    # global_step counts optimizer updates; step counts micro-batches.
    global_step, step, epoch = 0, 0, 0
    tr_stat = Statistics()
    logger.info("start training")
    model.train()
    while global_step <= args.total_train_steps:
        for batch in train_data:
            batch = move_to_device(batch, device)
            loss, acc, bsz = model(batch['src_tokens'], batch['tgt_tokens'],
                                   args.label_smoothing)
            # Accumulate sample-weighted loss/accuracy for periodic reporting.
            tr_stat.update({
                'loss': loss.item() * bsz,
                'nsamples': bsz,
                'acc': acc * bsz
            })
            tr_stat.step()
            loss.backward()

            step += 1
            # `step % k == -1 % k` is true on every k-th micro-batch
            # (k = gradient_accumulation_steps); otherwise keep accumulating
            # gradients without stepping the optimizer.
            if not (step % args.gradient_accumulation_steps
                    == -1 % args.gradient_accumulation_steps):
                continue

            # Synchronize (average) gradients across ranks before updating.
            if args.world_size > 1:
                average_gradients(model)

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_schedule.step()
            optimizer.zero_grad()
            global_step += 1

            # Logging / evaluation / checkpointing happen on rank 0 only.
            if args.world_size == 1 or (dist.get_rank() == 0):
                if global_step % args.print_every == -1 % args.print_every:
                    logger.info("epoch %d, step %d, loss %.3f, acc %.3f",
                                epoch, global_step,
                                tr_stat['loss'] / tr_stat['nsamples'],
                                tr_stat['acc'] / tr_stat['nsamples'])
                    tr_stat = Statistics()
                # Evaluate and save only after the LR warmup phase.
                if global_step > args.warmup_steps and global_step % args.eval_every == -1 % args.eval_every:
                    dev_data = DataLoader(vocabs,
                                          args.dev_data,
                                          args.dev_batch_size,
                                          addition=args.additional_negs)
                    acc = validate(model, dev_data, device)
                    logger.info("epoch %d, step %d, dev, dev acc %.2f", epoch,
                                global_step, acc)
                    save_path = '%s/epoch%d_batch%d_acc%.2f' % (
                        args.ckpt, epoch, global_step, acc)
                    model.save(args, save_path)
                    # Restore train mode (validate() presumably switches the
                    # model to eval mode — confirm against its definition).
                    model.train()
            if global_step > args.total_train_steps:
                break
        epoch += 1
    logger.info('rank %d, finish training after %d steps', local_rank,
                global_step)
Пример #22
0
    def step(self):
        """Run one adversarial training iteration.

        Performs a generator forward pass, updates the discriminator on
        real vs. detached fake outputs, then updates the generator with
        the adversarial loss plus the weighted criterion losses.

        Returns:
            dict mapping loss names to (world-size-normalized) loss tensors,
            including 'dis' (discriminator) and 'adv' (generator adversarial).
        """
        def _with_cond(x):
            # Append the modal condition channels when running conditionally.
            return torch.cat([x, self.modal], dim=1) if self.with_modal else x

        # Generator forward pass.
        if self.with_modal:
            pred, _ = self.model(torch.cat([self.rgb, self.modal], dim=1),
                                 self.visible_mask4)
        else:
            pred, _ = self.model(self.rgb, self.visible_mask3)
        # Upsample to the input resolution if the network output is smaller.
        if pred.shape[2] != self.rgb.shape[2]:
            pred = nn.functional.interpolate(pred,
                                             size=self.rgb.shape[2:4],
                                             mode="bilinear",
                                             align_corners=True)

        # Discriminator losses: real target vs. detached fake output.
        d_real, _ = self.netD(_with_cond(self.rgb_gt))
        d_fake, _ = self.netD(_with_cond(pred.detach()))
        d_real_loss = self.gan_criterion(d_real, True,
                                         True) / self.world_size
        d_fake_loss = self.gan_criterion(d_fake, False,
                                         True) / self.world_size
        d_loss = (d_real_loss + d_fake_loss) / 2

        # Generator adversarial loss (discriminator sees the live output).
        g_fake, _ = self.netD(_with_cond(pred))
        adv_loss = self.gan_criterion(g_fake, True, False) * \
            self.params['adv_loss_weight'] / self.world_size
        g_loss = adv_loss

        # Remaining (reconstruction/perceptual) losses from the criterion,
        # normalized by world size, then combined with per-loss weights.
        loss_dict = self.criterion(self.rgb, self.visible_mask3, pred,
                                   self.rgb_gt)
        for name in loss_dict.keys():
            loss_dict[name] /= self.world_size
        for name, weight in self.params['lambda_dict'].items():
            g_loss = g_loss + weight * loss_dict[name]

        # Expose the GAN losses alongside the criterion losses.
        loss_dict['dis'] = d_loss
        loss_dict['adv'] = adv_loss

        # Discriminator update first, then the generator update.
        self.optimD.zero_grad()
        d_loss.backward()
        utils.average_gradients(self.netD)
        self.optimD.step()

        self.optim.zero_grad()
        g_loss.backward()
        utils.average_gradients(self.model)
        self.optim.step()

        return loss_dict
Пример #23
0
    def train_model(self, chpt_path):
        """Build the multi-GPU training graph and run the optimization loop.

        Averages per-tower gradients on the CPU, optionally tracks variable
        moving averages, restores from a previous checkpoint in the save dir
        (or falls back to `chpt_path` via an init function), then trains for
        `self.num_train_steps`, periodically logging, writing summaries and
        saving checkpoints.

        Args:
            chpt_path: Path of a pretrained checkpoint used for weight
                initialization when no previous checkpoint exists.
        """
        print('Restoring from: {}'.format(chpt_path))
        g = tf.Graph()
        with g.as_default():
            with tf.device('/cpu:0'):
                # Init global step
                self.global_step = tf.train.create_global_step()

                batch_queue = self.get_data_queue()
                opt = self.optimizer()

                # Calculate the gradients for each model tower.
                tower_grads = []
                loss = None
                layers = None
                with tf.variable_scope(tf.get_variable_scope()):
                    for i in range(self.num_gpus):
                        with tf.device('/gpu:%d' % i):
                            with tf.name_scope('tower_{}'.format(i)) as scope:
                                loss, grads, layers = self.build_model(
                                    batch_queue, i, opt, scope)
                                tower_grads.append(grads)
                # Average per-tower gradients for a synchronous update.
                grad = average_gradients(tower_grads)

                # Make summaries
                self.make_summaries(grad, layers)

                # Apply the gradients to adjust the shared variables.
                apply_gradient_op = opt.apply_gradients(
                    grad, global_step=self.global_step)

                if self.var_avg:
                    # Track the moving averages of all trainable variables.
                    variable_averages = tf.train.ExponentialMovingAverage(
                        self.moving_avgs_decay, self.global_step)
                    variables_averages_op = variable_averages.apply(
                        tf.trainable_variables())

                    # Group all updates into a single train op.
                    apply_gradient_op = tf.group(apply_gradient_op,
                                                 variables_averages_op)

                # Running train_op yields the loss after the update is applied.
                train_op = control_flow_ops.with_dependencies(
                    [apply_gradient_op], loss)

                # Create a saver.
                saver = tf.train.Saver(tf.global_variables())
                init_fn = self.make_init_fn(chpt_path)

                # Build the summary operation from the last tower summaries.
                summary_op = tf.summary.merge(self.summaries)

                # Build an initialization operation to run below.
                init = tf.global_variables_initializer()

                # Start running operations on the Graph.
                sess = tf.Session(config=tf.ConfigProto(
                    allow_soft_placement=True, log_device_placement=False),
                                  graph=g)
                sess.run(init)
                # Prefer resuming our own previous checkpoint over init_fn.
                prev_ckpt = get_checkpoint_path(self.get_save_dir())
                if prev_ckpt:
                    print('Restoring from previous checkpoint: {}'.format(
                        prev_ckpt))
                    saver.restore(sess, prev_ckpt)
                elif init_fn:
                    init_fn(sess)

                summary_writer = tf.summary.FileWriter(self.get_save_dir(),
                                                       sess.graph)
                init_step = sess.run(self.global_step)
                print('Start training at step: {}'.format(init_step))

                # BUGFIX: the cadences below previously used true division
                # (`/`), which yields a float period in Python 3; the
                # `step % <float> == 0` test then only ever fired at step 0
                # unless the division happened to be exact.  Use floor
                # division (as the sibling trainer does) and clamp to >= 1.
                summary_every = max(
                    1, self.num_train_steps // self.num_summary_steps)
                ckpt_every = summary_every * 4

                for step in range(init_step, self.num_train_steps):

                    start_time = time.time()
                    _, loss_value = sess.run([train_op, loss])
                    duration = time.time() - start_time

                    assert not np.isnan(
                        loss_value), 'Model diverged with loss = NaN'

                    if step % 50 == 0:
                        num_examples_per_step = self.model.batch_size * self.num_gpus
                        examples_per_sec = num_examples_per_step / duration
                        sec_per_batch = duration / self.num_gpus
                        print(
                            '{}: step {}/{}, loss = {} ({} examples/sec; {} sec/batch)'
                            .format(datetime.now(), step, self.num_train_steps,
                                    loss_value, examples_per_sec,
                                    sec_per_batch))
                        sys.stdout.flush()

                    if step % summary_every == 0:
                        print('Writing summaries...')
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, step)

                    # Save the model checkpoint periodically.
                    if step % ckpt_every == 0 or (
                            step + 1) == self.num_train_steps:
                        checkpoint_path = os.path.join(self.get_save_dir(),
                                                       'model.ckpt')
                        print(
                            'Saving checkpoint to: {}'.format(checkpoint_path))
                        saver.save(sess, checkpoint_path, global_step=step)
Пример #24
0
    def build_train_model(self, test=True, reuse=None):
        """Build the multi-device training graph.

        One encoder/decoder replica is built per device; variables are
        spread across devices by a load-balancing setter and shared via a
        daisy-chain custom getter.  Per-device gradients are averaged,
        optionally clipped, and applied as `self.train_op`.

        Args:
            test: Whether to also build the test graph (with reuse=True).
            reuse: Variable-scope reuse flag forwarded to the model scopes.
        """
        logging.info('Build train model.')
        self.prepare_training()

        with self.graph.as_default():
            acc_list, loss_list, gv_list = [], [], []
            cache = {}
            # Running per-device element counts used for variable balancing.
            load = dict([(d, 0) for d in self._devices])
            for i, (X, Y, device) in enumerate(
                    zip(self.src_pls, self.dst_pls, self._devices)):

                def daisy_chain_getter(getter, name, *args, **kwargs):
                    """Get a variable and cache in a daisy chain."""
                    device_var_key = (device, name)
                    if device_var_key in cache:
                        # if we have the variable on the correct device, return it.
                        return cache[device_var_key]
                    if name in cache:
                        # if we have it on a different device, copy it from the last device
                        v = tf.identity(cache[name])
                    else:
                        var = getter(name, *args, **kwargs)
                        v = tf.identity(var._ref())  # pylint: disable=protected-access
                    # update the cache
                    cache[name] = v
                    cache[device_var_key] = v
                    return v

                def balanced_device_setter(op):
                    """Balance variables to all devices."""
                    if op.type in {'Variable', 'VariableV2', 'VarHandleOp'}:
                        # return self._sync_device
                        # Place the variable on a least-loaded device (ties
                        # broken randomly), then charge its element count.
                        min_load = min(load.values())
                        min_load_devices = [
                            d for d in load if load[d] == min_load
                        ]
                        chosen_device = random.choice(min_load_devices)
                        load[chosen_device] += op.outputs[0].get_shape(
                        ).num_elements()
                        return chosen_device
                    return device

                def identity_device_setter(op):
                    # Alternative setter: keep every op on the tower device.
                    return device

                device_setter = balanced_device_setter

                with tf.variable_scope(tf.get_variable_scope(),
                                       initializer=self._initializer,
                                       custom_getter=daisy_chain_getter,
                                       reuse=reuse):
                    with tf.device(device_setter):
                        logging.info('Build model on %s.' % device)
                        # Reuse variables for every tower after the first.
                        encoder_output = self.encoder(
                            X,
                            is_training=True,
                            reuse=i > 0 or None,
                            encoder_scope=self.encoder_scope)
                        decoder_output = self.decoder(
                            shift_right(Y),
                            encoder_output,
                            is_training=True,
                            reuse=i > 0 or None,
                            decoder_scope=self.decoder_scope)
                        acc, loss = self.train_output(
                            decoder_output,
                            Y,
                            reuse=i > 0 or None,
                            decoder_scope=self.decoder_scope)
                        acc_list.append(acc)
                        loss_list.append(loss)

                        # Optionally restrict training to variables matching
                        # the configured filter regex.
                        var_list = tf.trainable_variables()
                        if self._config.train.var_filter:
                            var_list = [
                                v for v in var_list if re.match(
                                    self._config.train.var_filter, v.name)
                            ]
                        gv_list.append(
                            self._optimizer.compute_gradients(
                                loss, var_list=var_list))

            self.accuracy = tf.reduce_mean(acc_list)
            self.loss = tf.reduce_mean(loss_list)

            # Clip gradients and then apply.
            grads_and_vars = average_gradients(gv_list)
            avg_abs_grads = tf.reduce_mean(tf.abs(grads_and_vars[0]))

            if self._config.train.grads_clip > 0:
                grads, self.grads_norm = tf.clip_by_global_norm(
                    [gv[0] for gv in grads_and_vars],
                    clip_norm=self._config.train.grads_clip)
                grads_and_vars = zip(grads, [gv[1] for gv in grads_and_vars])
            else:
                self.grads_norm = tf.global_norm(
                    [gv[0] for gv in grads_and_vars])

            self.train_op = self._optimizer.apply_gradients(
                grads_and_vars, global_step=self.global_step)

            # Summaries
            tf.summary.scalar('acc', self.accuracy)
            tf.summary.scalar('loss', self.loss)
            tf.summary.scalar('learning_rate', self.learning_rate)
            tf.summary.scalar('grads_norm', self.grads_norm)
            tf.summary.scalar('avg_abs_grads', avg_abs_grads)
            self.summary_op = tf.summary.merge_all()

            self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                        max_to_keep=20)

        # We may want to test the model during training.
        if test:
            self.build_test_model(reuse=True)
Пример #25
0
def main():
    """Build and run energy-based-model training with Langevin-dynamics
    negative sampling, distributed via Horovod.

    All configuration comes from FLAGS: dataset choice, model size,
    sampler settings and training objective.  Rank 0 owns logging and
    checkpoint restoration; variables are broadcast to all workers.
    """
    print("Local rank: ", hvd.local_rank(), hvd.size())

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    # Only rank 0 creates the log dir and writes TensorBoard output.
    if hvd.rank() == 0:
        if not osp.exists(logdir):
            os.makedirs(logdir)
        logger = TensorBoardOutputFormat(logdir)
    else:
        logger = None

    LABEL = None
    print("Loading data...")
    # Per-dataset placeholders (noise input, data input, labels) and model.
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
        test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        if FLAGS.large_model:
            model = ResNet32Large(num_channels=channel_num,
                                  num_filters=128,
                                  train=True)
        elif FLAGS.larger_model:
            model = ResNet32Larger(num_channels=channel_num, num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(num_channels=channel_num, num_filters=192)
        else:
            model = ResNet32(num_channels=channel_num, num_filters=128)

    elif FLAGS.dataset == 'imagenet':
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet32Wider(num_channels=channel_num, num_filters=256)

    elif FLAGS.dataset == 'imagenetfull':
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet128(num_channels=channel_num, num_filters=64)

    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(rescale=FLAGS.rescale)
        test_dataset = dataset
        channel_num = 1
        X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        model = MnistNet(num_channels=channel_num,
                         num_filters=FLAGS.num_filters)

    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites(cond_shape=FLAGS.cond_shape,
                           cond_size=FLAGS.cond_size,
                           cond_pos=FLAGS.cond_pos,
                           cond_rot=FLAGS.cond_rot)
        test_dataset = dataset
        channel_num = 1

        X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

        # Label width depends on which latent factor(s) are conditioned on.
        if FLAGS.dpos_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.dsize_only:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.drot_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.cond_shape:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_rot:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        else:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)

        model = DspritesNet(num_channels=channel_num,
                            num_filters=FLAGS.num_filters,
                            cond_size=FLAGS.cond_size,
                            cond_shape=FLAGS.cond_shape,
                            cond_pos=FLAGS.cond_pos,
                            cond_rot=FLAGS.cond_rot)

    print("Done loading...")

    if FLAGS.dataset == "imagenetfull":
        # In the case of full imagenet, use custom_tensorflow dataloader
        data_loader = TFImagenetLoader('train',
                                       FLAGS.batch_size,
                                       hvd.rank(),
                                       hvd.size(),
                                       rescale=FLAGS.rescale)
    else:
        data_loader = DataLoader(dataset,
                                 batch_size=FLAGS.batch_size,
                                 num_workers=FLAGS.data_workers,
                                 drop_last=True,
                                 shuffle=True)

    batch_size = FLAGS.batch_size

    weights = [model.construct_weights('context_0')]

    Y = tf.placeholder(shape=(None), dtype=tf.int32)

    # Varibles to run in training
    # Inputs are split across GPUs; each split builds its own tower below.
    X_SPLIT = tf.split(X, FLAGS.num_gpus)
    X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
    LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
    LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
    LABEL_SPLIT_INIT = list(LABEL_SPLIT)
    tower_grads = []
    tower_gen_grads = []
    x_mod_list = []

    optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    optimizer = hvd.DistributedOptimizer(optimizer)

    for j in range(FLAGS.num_gpus):

        if FLAGS.model_cclass:
            # Class-conditional marginalization: replicate each image across
            # all 10 class labels, then sample a label via the Gumbel trick
            # from the per-class energies.
            ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus
            label_tensor = tf.Variable(tf.convert_to_tensor(np.reshape(
                np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)),
                (FLAGS.batch_size * 10, 10)),
                                                            dtype=tf.float32),
                                       trainable=False,
                                       dtype=tf.float32)
            x_split = tf.tile(
                tf.reshape(X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)),
                (1, 10, 1, 1, 1))
            x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
            energy_pos = model.forward(x_split,
                                       weights[0],
                                       label=label_tensor,
                                       stop_at_grad=False)

            energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
            energy_partition_est = tf.reduce_logsumexp(energy_pos_full,
                                                       axis=1,
                                                       keepdims=True)
            uniform = tf.random_uniform(tf.shape(energy_pos_full))
            label_tensor = tf.argmax(-energy_pos_full -
                                     tf.log(-tf.log(uniform)) -
                                     energy_partition_est,
                                     axis=1)
            label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
            label = tf.Print(label, [label_tensor, energy_pos_full])
            LABEL_SPLIT[j] = label
            energy_pos = tf.concat(energy_pos, axis=0)
        else:
            # Energy of the real ("positive") samples.
            energy_pos = [
                model.forward(X_SPLIT[j],
                              weights[0],
                              label=LABEL_POS_SPLIT[j],
                              stop_at_grad=False)
            ]
            energy_pos = tf.concat(energy_pos, axis=0)

        print("Building graph...")
        x_mod = x_orig = X_NOISE_SPLIT[j]

        x_grads = []

        energy_negs = []
        loss_energys = []

        energy_negs.extend([
            model.forward(tf.stop_gradient(x_mod),
                          weights[0],
                          label=LABEL_SPLIT[j],
                          stop_at_grad=False,
                          reuse=True)
        ])
        eps_begin = tf.zeros(1)

        steps = tf.constant(0)
        c = lambda i, x: tf.less(i, FLAGS.num_steps)

        def langevin_step(counter, x_mod):
            """One Langevin-dynamics update: add noise, step down the
            energy gradient, and clip back into the valid pixel range."""
            x_mod = x_mod + tf.random_normal(
                tf.shape(x_mod),
                mean=0.0,
                stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)

            energy_noise = energy_start = tf.concat([
                model.forward(x_mod,
                              weights[0],
                              label=LABEL_SPLIT[j],
                              reuse=True,
                              stop_at_grad=False,
                              stop_batch=True)
            ],
                                                    axis=0)

            x_grad, label_grad = tf.gradients(FLAGS.temperature * energy_noise,
                                              [x_mod, LABEL_SPLIT[j]])
            energy_noise_old = energy_noise

            lr = FLAGS.step_lr

            # Optionally project the sampler gradient to bound step size.
            if FLAGS.proj_norm != 0.0:
                if FLAGS.proj_norm_type == 'l2':
                    x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                elif FLAGS.proj_norm_type == 'li':
                    x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm,
                                              FLAGS.proj_norm)
                else:
                    print("Other types of projection are not supported!!!")
                    assert False

            # Clip gradient norm for now
            if FLAGS.hmc:
                # Step size should be tuned to get around 65% acceptance
                def energy(x):
                    return FLAGS.temperature * \
                        model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)

                x_last = hmc(x_mod, 15., 10, energy)
            else:
                x_last = x_mod - (lr) * x_grad

            x_mod = x_last
            x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

            counter = counter + 1

            return counter, x_mod

        # Run FLAGS.num_steps Langevin updates to produce negative samples.
        steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))

        energy_eval = model.forward(x_mod,
                                    weights[0],
                                    label=LABEL_SPLIT[j],
                                    stop_at_grad=False,
                                    reuse=True)
        x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
        x_grads.append(x_grad)

        # Energy of the sampled ("negative") examples.
        energy_negs.append(
            model.forward(tf.stop_gradient(x_mod),
                          weights[0],
                          label=LABEL_SPLIT[j],
                          stop_at_grad=False,
                          reuse=True))

        test_x_mod = x_mod

        temp = FLAGS.temperature

        energy_neg = energy_negs[-1]
        x_off = tf.reduce_mean(
            tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

        loss_energy = model.forward(x_mod,
                                    weights[0],
                                    reuse=True,
                                    label=LABEL,
                                    stop_grad=True)

        print("Finished processing loop construction ...")

        target_vars = {}

        if FLAGS.cclass or FLAGS.model_cclass:
            # Entropy of the empirical label distribution (diagnostic).
            label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
            label_prob = label_sum / tf.reduce_sum(label_sum)
            label_ent = -tf.reduce_sum(
                label_prob * tf.math.log(label_prob + 1e-7))
        else:
            label_ent = tf.zeros(1)

        target_vars['label_ent'] = label_ent

        if FLAGS.train:

            # Training objective: contrast positive vs. negative energies.
            if FLAGS.objective == 'logsumexp':
                pos_term = temp * energy_pos
                energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'cd':
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = -tf.reduce_mean(temp * energy_neg)
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'softplus':
                loss_ml = FLAGS.ml_coeff * \
                    tf.nn.softplus(temp * (energy_pos - energy_neg))

            loss_total = tf.reduce_mean(loss_ml)

            if not FLAGS.zero_kl:
                loss_total = loss_total + tf.reduce_mean(loss_energy)

            # L2 regularization on energy magnitudes keeps outputs bounded.
            loss_total = loss_total + \
                FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

            print("Started gradient computation...")
            gvs = optimizer.compute_gradients(loss_total)
            gvs = [(k, v) for (k, v) in gvs if k is not None]

            print("Applying gradients...")

            tower_grads.append(gvs)

            print("Finished applying gradients.")

            target_vars['loss_ml'] = loss_ml
            target_vars['total_loss'] = loss_total
            target_vars['loss_energy'] = loss_energy
            target_vars['weights'] = weights
            target_vars['gvs'] = gvs

        # Tensors exposed to the train/test loops.
        target_vars['X'] = X
        target_vars['Y'] = Y
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['X_NOISE'] = X_NOISE
        target_vars['energy_pos'] = energy_pos
        target_vars['energy_start'] = energy_negs[0]

        if len(x_grads) >= 1:
            target_vars['x_grad'] = x_grads[-1]
            target_vars['x_grad_first'] = x_grads[0]
        else:
            target_vars['x_grad'] = tf.zeros(1)
            target_vars['x_grad_first'] = tf.zeros(1)

        target_vars['x_mod'] = x_mod
        target_vars['x_off'] = x_off
        target_vars['temp'] = temp
        target_vars['energy_neg'] = energy_neg
        target_vars['test_x_mod'] = test_x_mod
        target_vars['eps_begin'] = eps_begin

    if FLAGS.train:
        # Average the per-tower gradients and build the train op.
        grads = average_gradients(tower_grads)
        train_op = optimizer.apply_gradients(grads)
        target_vars['train_op'] = train_op

    config = tf.ConfigProto()

    if hvd.size() > 1:
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(max_to_keep=30,
                                    keep_checkpoint_every_n_hours=6)

    # Count trainable parameters for a quick sanity log.
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    # Rank 0 restores the checkpoint; other ranks receive the weights via
    # the broadcast below.
    if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    sess.run(hvd.broadcast_global_variables(0))
    print("Initializing variables...")

    print("Start broadcast")
    print("End broadcast")

    if FLAGS.train:
        print("Training phase")
        train(target_vars, saver, sess, logger, data_loader, resume_itr,
              logdir)
    print("Testing phase")
    test(target_vars, saver, sess, logger, data_loader)
Пример #26
0
                losses_plain.append(loss_plain_tower)
                losses_reg.append(loss_reg_tower)
                regs_rb.append(reg_rb_tower)
                regs_db.append(reg_db_tower)
                err_rates.append(err_rate_tower)

                # Reuse variables for this tower.
                tf.get_variable_scope().reuse_variables()

    loss_plain = tf.reduce_mean(tf.stack(losses_plain))
    loss_reg = tf.reduce_mean(tf.stack(losses_reg))
    reg_rb = tf.reduce_mean(tf.stack(regs_rb))
    reg_db = tf.reduce_mean(tf.stack(regs_db))
    err_rate = tf.reduce_mean(tf.stack(err_rates))

    grads_vars = utils.average_gradients(tower_grads)
    train_step_loss_reg = optimizer.apply_gradients(grads_vars,
                                                    name='train_step')

    # Separate forward pass graph for Cleverhans wrapper (for PGD attack) placed on the last GPU
    logits_all_gpus = forward_pass_cleverhans(x_tf)

    # Model saver
    saver = tf.train.Saver()
    # GPU settings
    gpu_options = tf.GPUOptions(visible_device_list=str(hps.gpus)[1:-1],
                                per_process_gpu_memory_fraction=hps.gpu_memory)
    config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)

with tf.Session(graph=graph, config=config) as sess:
    with graph.as_default(), tf.device('/gpu:0'):
def main():
    """Benchmark synchronous data-parallel training throughput on GPUs.

    Builds a simple classification network replicated once per GPU in
    ``--device_list``, averages the per-tower gradients (on the CPU
    parameter server), and reports images/sec over ``--num_iterations``
    after ``--num_warmup`` warm-up iterations.  Inputs are random
    synthetic images/labels, so only raw compute throughput is measured.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument(
        "--device_list",
        help="Comma separated device IDs to run benchmark on.",
        type=str,
        default="0")
    parser.add_argument("--batch_size_per_gpu",
                        help="Batch size on each GPU",
                        type=int,
                        default=32)
    parser.add_argument("--num_classes",
                        help="Number of classes",
                        type=int,
                        default=100)
    parser.add_argument("--num_warmup",
                        help="Number of warm up iterations.",
                        type=int,
                        default=50)
    parser.add_argument("--num_iterations",
                        help="Number of benchmark iterations.",
                        type=int,
                        default=200)
    config = parser.parse_args()

    config.device_list = list(map(int, config.device_list.split(",")))
    config.gpu_count = len(config.device_list)

    # Synthetic input data: one global batch covering all GPUs.
    x = np.random.rand(config.gpu_count * config.batch_size_per_gpu, 224, 224,
                       3)
    y = np.random.randint(config.num_classes,
                          size=(config.gpu_count * config.batch_size_per_gpu))

    # Placeholders live on the CPU; each tower slices out its own shard.
    with tf.device("/cpu:0"):
        image = tf.placeholder(tf.float32,
                               shape=(config.gpu_count *
                                      config.batch_size_per_gpu, 224, 224, 3))
        label = tf.placeholder(tf.int32,
                               shape=(config.gpu_count *
                                      config.batch_size_per_gpu))

    # Create the optimizer ONCE so every tower shares the same instance.
    # (Previously a new MomentumOptimizer was built inside the loop and
    # apply_gradients silently used only the last one; a single shared
    # optimizer keeps one coherent set of slot variables.)
    optimizer = tf.train.MomentumOptimizer(learning_rate=0.001, momentum=0.9)

    list_grads_and_vars = []

    # Map: build one tower per device, collecting its gradients.
    for split_id, device_id in enumerate(config.device_list):
        with tf.device(
                utils.assign_to_device("/gpu:{}".format(device_id),
                                       ps_device="/cpu:0")):

            # Split input data across multiple devices
            images_batch, labels_batch = utils.batch_split(
                (image, label), split_id, config.batch_size_per_gpu)

            outputs = net.simple_net(images_batch, config.batch_size_per_gpu,
                                     config.num_classes)

            loss = tf.losses.sparse_softmax_cross_entropy(labels=labels_batch,
                                                          logits=outputs)

            list_grads_and_vars.append(optimizer.compute_gradients(loss))

    # Reduce: average gradients across towers, then apply once.
    ave_grads_and_vars = utils.average_gradients(list_grads_and_vars)

    minimize_op = optimizer.apply_gradients(ave_grads_and_vars)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.90,
                                allow_growth=True)

    session_config = tf.ConfigProto(allow_soft_placement=False,
                                    log_device_placement=False,
                                    gpu_options=gpu_options)

    with tf.Session(config=session_config) as sess:
        sess.run(tf.global_variables_initializer())

        # Warm up so one-time costs (autotuning, allocation) are excluded.
        print("Warm up started.")
        for i_iter in range(config.num_warmup):
            sess.run(minimize_op, feed_dict={image: x, label: y})
        print("Warm up finished.")

        start_time = time.time()
        for i_iter in range(config.num_iterations):
            print("\rIteration: " + str(i_iter), end="")
            sys.stdout.flush()
            sess.run(minimize_op, feed_dict={image: x, label: y})
        end_time = time.time()

        total_time = end_time - start_time
        total_num_images = (config.gpu_count * config.batch_size_per_gpu *
                            config.num_iterations)

        print("\nTotal time spend: " + str(total_time) + " secs.")
        print("Average Speed: " + str(total_num_images / total_time) +
              " images/sec.")
Пример #28
0
                    correct_prediction = tf.equal(tf.argmax(model, 1), y)
                    accuracy = tf.reduce_mean(
                        tf.cast(correct_prediction, tf.float32))
                    towers_acc.append(accuracy)

                    correct_prediction5 = tf.nn.in_top_k(model, y, k=5)
                    accuracy5 = tf.reduce_mean(
                        tf.cast(correct_prediction5, tf.float32))
                    towers_acc5.append(accuracy5)

                    tf.get_variable_scope().reuse_variables()
        pass

    # aggregate all gradients
    grads = average_gradients(towers_grad)
    acc1 = tf.reduce_mean(towers_acc)
    acc5 = tf.reduce_mean(towers_acc5)
    train_step = train_step.apply_gradients(grads, global_step=global_step)
    lss = tf.reduce_mean(towers_loss)

    # Create a summary to monitor cost tensor
    tf.summary.scalar("loss", lss)
    # Create a summary to monitor accuracy tensor
    tf.summary.scalar("accuracy-top1", acc1)
    tf.summary.scalar("accuracy-top5", acc5)

    # Merge all summaries into a single op
    merged_summary_op = tf.summary.merge_all()

    # Track the moving averages of all trainable variables.
Пример #29
0
def main():
    """Build and run training/testing for a (conditional) energy-based model.

    Selects a dataset from FLAGS.dataset, creates the matching label
    placeholders, replicates the EBM graph over FLAGS.num_gpus towers
    (negatives sampled via Langevin dynamics inside tf.while_loop),
    averages tower gradients, then dispatches to train()/test().

    NOTE(review): relies on many module-level names (FLAGS, dataset
    classes, model classes, average_gradients, optimistic_restore,
    train, test) defined elsewhere in the file.
    """

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    logger = TensorBoardOutputFormat(logdir)

    config = tf.ConfigProto()

    sess = tf.Session(config=config)
    LABEL = None
    print("Loading data...")
    # Dataset selection: each branch defines LABEL/LABEL_POS placeholders
    # whose width (label_size) matches that dataset's conditioning vector.
    if FLAGS.dataset == 'cubes':
        dataset = Cubes(cond_idx=FLAGS.cond_idx)
        test_dataset = dataset

        if FLAGS.cond_idx == 0:
            label_size = 2
        elif FLAGS.cond_idx == 1:
            label_size = 1
        elif FLAGS.cond_idx == 2:
            label_size = 3
        elif FLAGS.cond_idx == 3:
            label_size = 20

        LABEL = tf.placeholder(shape=(None, label_size), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, label_size), dtype=tf.float32)
    elif FLAGS.dataset == 'color':
        dataset = CubesColor()
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 301), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 301), dtype=tf.float32)
        label_size = 301
    elif FLAGS.dataset == 'pos':
        dataset = CubesPos()
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        label_size = 2
    elif FLAGS.dataset == "pairs":
        dataset = Pairs(cond_idx=0)
        test_dataset = dataset
        LABEL = tf.placeholder(shape=(None, 6), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 6), dtype=tf.float32)
        label_size = 6
    elif FLAGS.dataset == "continual":
        dataset = CubesContinual()
        test_dataset = dataset

        if FLAGS.prelearn_model_shape:
            LABEL = tf.placeholder(shape=(None, 20), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 20), dtype=tf.float32)
            label_size = 20
        else:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            label_size = 2

    elif FLAGS.dataset == "cross":
        dataset = CubesCrossProduct(FLAGS.ratio, cond_size=FLAGS.cond_size, cond_pos=FLAGS.cond_pos, joint_baseline=FLAGS.joint_baseline)
        test_dataset = dataset

        if FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            label_size = 1
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            label_size = 2

        # NOTE(review): joint_baseline overrides any placeholder created
        # by cond_size/cond_pos above.
        if FLAGS.joint_baseline:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            label_size = 3

    elif FLAGS.dataset == 'celeba':
        # CelebA uses 128x128 RGB images and a ResNet128 model; the other
        # datasets use 64x64 placeholders defined later.
        dataset = CelebA(cond_idx=FLAGS.celeba_cond_idx)
        test_dataset = dataset
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)

        model = ResNet128(
            num_channels=channel_num,
            num_filters=64,
            classes=2)

    if FLAGS.joint_baseline:
        # Joint baseline: a plain generator trained with MSE instead of an
        # EBM; most target_vars entries are zero stubs so the shared
        # train()/test() loop still finds every key it expects.
        optimizer = AdamOptimizer(FLAGS.lr, beta1=0.99, beta2=0.999)

        X = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        X_NOISE = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        ATTENTION_MASK = tf.placeholder(shape=(None, 64, 64, FLAGS.cond_func), dtype=tf.float32)
        NOISE = tf.placeholder(shape=(None, 128), dtype=tf.float32)
        HIER_LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)

        channel_num = 3

        model = CubesNetGen(num_channels=channel_num, label_size=label_size)
        weights = model.construct_weights('context_0')
        output = model.forward(NOISE, weights, reuse=False, label=LABEL)
        print(output.get_shape())
        mse_loss = tf.reduce_mean(tf.square(output - X))
        gvs = optimizer.compute_gradients(mse_loss)
        train_op = optimizer.apply_gradients(gvs)
        # NOTE(review): gvs is filtered only after apply_gradients was
        # already built from the unfiltered list.
        gvs = [(k, v) for (k, v) in gvs if k is not None]

        target_vars = {}
        target_vars['train_op'] = train_op
        target_vars['X'] = X
        target_vars['X_NOISE'] = X_NOISE
        target_vars['ATTENTION_MASK'] = ATTENTION_MASK
        target_vars['eps_begin'] = tf.zeros(1)
        target_vars['gvs'] = gvs
        target_vars['energy_pos'] = tf.zeros(1)
        target_vars['energy_neg'] = tf.zeros(1)
        target_vars['loss_energy'] = tf.zeros(1)
        target_vars['loss_ml'] = tf.zeros(1)
        target_vars['total_loss'] = mse_loss
        target_vars['attention_mask'] = tf.zeros(1)
        target_vars['attention_grad'] = tf.zeros(1)
        target_vars['x_off'] = tf.reduce_mean(tf.abs(output - X))
        target_vars['x_mod'] = tf.zeros(1)
        target_vars['x_grad'] = tf.zeros(1)
        target_vars['NOISE'] = NOISE
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['HIER_LABEL'] = HIER_LABEL

        data_loader = DataLoader(
            dataset,
            batch_size=FLAGS.batch_size,
            num_workers=FLAGS.data_workers,
            drop_last=True,
            shuffle=True)
    else:
        # EBM path: build the energy model, optional pre-learned models,
        # and the multi-GPU tower graph with Langevin sampling.
        print("label size here ", label_size)
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64, 3), dtype=tf.float32)
        HEIR_LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        ATTENTION_MASK = tf.placeholder(shape=(None, 64, 64, FLAGS.cond_func), dtype=tf.float32)

        if FLAGS.dataset != "celeba":
            model = CubesNet(num_channels=channel_num, label_size=label_size)

        heir_model = HeirNet(num_channels=FLAGS.cond_func)

        # Optionally restore previously-trained conditional models whose
        # energies are added to the Langevin sampling objective
        # (compositional generation).  Their checkpoints were saved under
        # scope 'context_0', hence the scope-renaming var map.
        models_pretrain = []
        if FLAGS.prelearn_model:
            model_prelearn = CubesNet(num_channels=channel_num, label_size=FLAGS.prelearn_label)
            weights = model_prelearn.construct_weights('context_1')
            LABEL_PRELEARN = tf.placeholder(shape=(None, FLAGS.prelearn_label), dtype=tf.float32)
            models_pretrain.append((model_prelearn, weights, LABEL_PRELEARN))

            cubes_logdir = osp.join(FLAGS.logdir, FLAGS.prelearn_exp)
            if (FLAGS.prelearn_iter != -1 or not FLAGS.train):
                model_file = osp.join(cubes_logdir, 'model_{}'.format(FLAGS.prelearn_iter))
                resume_itr = FLAGS.resume_iter
                # saver.restore(sess, model_file)

                v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1))
                v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list}
                saver = tf.train.Saver(v_map)
                saver.restore(sess, model_file)

        if FLAGS.prelearn_model_shape:
            model_prelearn = CubesNet(num_channels=channel_num, label_size=FLAGS.prelearn_label_shape)
            weights = model_prelearn.construct_weights('context_2')
            LABEL_PRELEARN = tf.placeholder(shape=(None, FLAGS.prelearn_label_shape), dtype=tf.float32)
            models_pretrain.append((model_prelearn, weights, LABEL_PRELEARN))

            cubes_logdir = osp.join(FLAGS.logdir, FLAGS.prelearn_exp_shape)
            if (FLAGS.prelearn_iter_shape != -1 or not FLAGS.train):
                model_file = osp.join(cubes_logdir, 'model_{}'.format(FLAGS.prelearn_iter_shape))
                resume_itr = FLAGS.resume_iter
                # saver.restore(sess, model_file)

                v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2))
                v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list}
                saver = tf.train.Saver(v_map)
                saver.restore(sess, model_file)

        print("Done loading...")

        data_loader = DataLoader(
            dataset,
            batch_size=FLAGS.batch_size,
            num_workers=FLAGS.data_workers,
            drop_last=True,
            shuffle=True)

        batch_size = FLAGS.batch_size

        weights = model.construct_weights('context_0')

        if FLAGS.heir_mask:
            weights = heir_model.construct_weights('heir_0', weights=weights)

        Y = tf.placeholder(shape=(None), dtype=tf.int32)

        # Varibles to run in training

        # Shard each placeholder across GPUs; tower j works on split j.
        X_SPLIT = tf.split(X, FLAGS.num_gpus)
        X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
        LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
        LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
        LABEL_SPLIT_INIT = list(LABEL_SPLIT)
        attention_mask = ATTENTION_MASK
        tower_grads = []
        tower_gen_grads = []
        x_mod_list = []

        optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.99)

        for j in range(FLAGS.num_gpus):

            x_mod = X_SPLIT[j]
            if FLAGS.comb_mask:
                # Refine the attention mask alone by Langevin dynamics
                # before computing the positive-phase energy.
                steps = tf.constant(0)
                c = lambda i, x: tf.less(i, FLAGS.num_steps)

                def langevin_attention_step(counter, attention_mask):
                    # Inject noise, descend the energy w.r.t. the mask,
                    # then smooth it with average pooling.
                    attention_mask = attention_mask + tf.random_normal(tf.shape(attention_mask), mean=0.0, stddev=0.01)
                    energy_noise = energy_start = model.forward(
                                x_mod,
                                weights,
                                attention_mask,
                                label=LABEL_SPLIT[j],
                                reuse=True,
                                stop_at_grad=False,
                                stop_batch=True)

                    if FLAGS.heir_mask:
                        energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                        energy_noise = energy_noise + energy_heir

                    attention_grad = tf.gradients(
                        FLAGS.temperature * energy_noise, [attention_mask])[0]
                    energy_noise_old = energy_noise

                    # Clip gradient norm for now
                    attention_mask = attention_mask - (FLAGS.attention_lr) * attention_grad
                    attention_mask = tf.layers.average_pooling2d(attention_mask, (3, 3), 1, padding='SAME')
                    attention_mask = tf.stop_gradient(attention_mask)

                    counter = counter + 1

                    return counter, attention_mask

                steps, attention_mask = tf.while_loop(c, langevin_attention_step, (steps, attention_mask))

                # attention_mask = tf.Print(attention_mask, [attention_mask])

                # Positive phase: energy of real data under the refined mask.
                energy_pos = model.forward(
                        X_SPLIT[j],
                        weights,
                        tf.stop_gradient(attention_mask),
                        label=LABEL_POS_SPLIT[j],
                        stop_at_grad=False)

                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_pos = energy_heir + energy_pos

            else:
                energy_pos = model.forward(
                        X_SPLIT[j],
                        weights,
                        attention_mask,
                        label=LABEL_POS_SPLIT[j],
                        stop_at_grad=False)

                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_pos = energy_heir + energy_pos

            print("Building graph...")
            # Negative phase: start from noise and run Langevin dynamics.
            x_mod = x_orig = X_NOISE_SPLIT[j]

            x_grads = []

            loss_energys = []

            eps_begin = tf.zeros(1)

            steps = tf.constant(0)
            c_cond = lambda i, x, y: tf.less(i, FLAGS.num_steps)

            def langevin_step(counter, x_mod, attention_mask):
                """One Langevin update of (x_mod, attention_mask).

                Adds Gaussian noise, sums the energies of the main model
                and any pre-learned models, and takes a gradient step on
                both the sample and (if comb_mask) the attention mask.
                Captures j/weights/models_pretrain from the enclosing tower.
                """

                lr = FLAGS.step_lr

                x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
                                                 mean=0.0,
                                                 stddev=0.001 * FLAGS.rescale * FLAGS.noise_scale)
                attention_mask = attention_mask + tf.random_normal(tf.shape(attention_mask), mean=0.0, stddev=0.01)

                energy_noise = model.forward(
                            x_mod,
                            weights,
                            attention_mask,
                            label=LABEL_SPLIT[j],
                            reuse=True,
                            stop_at_grad=False,
                            stop_batch=True)

                if FLAGS.prelearn_model:
                    for m_i, w_i, l_i in models_pretrain:
                        energy_noise = energy_noise + m_i.forward(
                                    x_mod,
                                    w_i,
                                    attention_mask,
                                    label=l_i,
                                    reuse=True,
                                    stop_at_grad=False,
                                    stop_batch=True)


                if FLAGS.heir_mask:
                    energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                    energy_noise = energy_heir + energy_noise

                x_grad, attention_grad = tf.gradients(
                    FLAGS.temperature * energy_noise, [x_mod, attention_mask])

                if not FLAGS.comb_mask:
                    attention_grad = tf.zeros(1)
                energy_noise_old = energy_noise

                # Optional projection of the sample gradient (l2 or l-inf).
                if FLAGS.proj_norm != 0.0:
                    if FLAGS.proj_norm_type == 'l2':
                        x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                    elif FLAGS.proj_norm_type == 'li':
                        x_grad = tf.clip_by_value(
                            x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
                    else:
                        print("Other types of projection are not supported!!!")
                        assert False

                # Clip gradient norm for now
                x_last = x_mod - (lr) * x_grad

                if FLAGS.comb_mask:
                    attention_mask = attention_mask - FLAGS.attention_lr * attention_grad
                    attention_mask = tf.layers.average_pooling2d(attention_mask, (3, 3), 1, padding='SAME')
                    attention_mask = tf.stop_gradient(attention_mask)

                x_mod = x_last
                # Keep samples inside the valid pixel range [0, rescale].
                x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

                counter = counter + 1

                return counter, x_mod, attention_mask


            steps, x_mod, attention_mask = tf.while_loop(c_cond, langevin_step, (steps, x_mod, attention_mask))

            attention_mask = tf.stop_gradient(attention_mask)
            # attention_mask = tf.Print(attention_mask, [attention_mask])

            energy_eval = model.forward(x_mod, weights, attention_mask, label=LABEL_SPLIT[j],
                                        stop_at_grad=False, reuse=True)
            x_grad, attention_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod, attention_mask])
            x_grads.append(x_grad)

            # Negative-phase energy on stopped samples (gradients flow only
            # into the model weights, not into the sampler).
            energy_neg = model.forward(
                    tf.stop_gradient(x_mod),
                    weights,
                    tf.stop_gradient(attention_mask),
                    label=LABEL_SPLIT[j],
                    stop_at_grad=False,
                    reuse=True)

            if FLAGS.heir_mask:
                energy_heir = 1.00 * heir_model.forward(attention_mask, weights, label=HEIR_LABEL)
                energy_neg = energy_heir + energy_neg


            temp = FLAGS.temperature

            # Mean absolute distance of negatives from real data (monitoring).
            x_off = tf.reduce_mean(
                tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

            loss_energy = model.forward(
                x_mod,
                weights,
                attention_mask,
                reuse=True,
                label=LABEL,
                stop_grad=True)

            print("Finished processing loop construction ...")

            target_vars = {}

            if FLAGS.antialias:
                # NOTE(review): stride_3 is defined elsewhere in the file;
                # inp is built but unused in this visible scope.
                antialias = tf.tile(stride_3, (1, 1, tf.shape(x_mod)[3], tf.shape(x_mod)[3]))
                inp = tf.nn.conv2d(x_mod, antialias, [1, 2, 2, 1], padding='SAME')

            test_x_mod = x_mod

            # Entropy of the label distribution in the batch (diagnostic).
            if FLAGS.cclass or FLAGS.model_cclass:
                label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
                label_prob = label_sum / tf.reduce_sum(label_sum)
                label_ent = -tf.reduce_sum(label_prob *
                                           tf.math.log(label_prob + 1e-7))
            else:
                label_ent = tf.zeros(1)

            target_vars['label_ent'] = label_ent

            if FLAGS.train:
                # Contrastive-divergence-style objectives: push down the
                # energy of data (pos) and up the energy of samples (neg).
                if FLAGS.objective == 'logsumexp':
                    pos_term = temp * energy_pos
                    energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                    coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                    norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                    pos_loss = tf.reduce_mean(temp * energy_pos)
                    neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                    loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
                elif FLAGS.objective == 'cd':
                    pos_loss = tf.reduce_mean(temp * energy_pos)
                    neg_loss = -tf.reduce_mean(temp * energy_neg)
                    loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
                elif FLAGS.objective == 'softplus':
                    loss_ml = FLAGS.ml_coeff * \
                        tf.nn.softplus(temp * (energy_pos - energy_neg))

                loss_total = tf.reduce_mean(loss_ml)

                if not FLAGS.zero_kl:
                    loss_total = loss_total + tf.reduce_mean(loss_energy)

                # L2 penalty on energy magnitudes keeps them from diverging.
                loss_total = loss_total + \
                    FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

                print("Started gradient computation...")
                gvs = optimizer.compute_gradients(loss_total)
                gvs = [(k, v) for (k, v) in gvs if k is not None]

                print("Applying gradients...")

                tower_grads.append(gvs)

                print("Finished applying gradients.")

                target_vars['loss_ml'] = loss_ml
                target_vars['total_loss'] = loss_total
                target_vars['loss_energy'] = loss_energy
                target_vars['weights'] = weights
                target_vars['gvs'] = gvs

            # NOTE(review): target_vars is rebuilt each tower iteration, so
            # after the loop it holds the tensors of the LAST tower only.
            target_vars['X'] = X
            target_vars['Y'] = Y
            target_vars['LABEL'] = LABEL
            target_vars['HIER_LABEL'] = HEIR_LABEL
            target_vars['LABEL_POS'] = LABEL_POS
            target_vars['X_NOISE'] = X_NOISE
            target_vars['energy_pos'] = energy_pos
            target_vars['attention_grad'] = attention_grad

            if len(x_grads) >= 1:
                target_vars['x_grad'] = x_grads[-1]
                target_vars['x_grad_first'] = x_grads[0]
            else:
                target_vars['x_grad'] = tf.zeros(1)
                target_vars['x_grad_first'] = tf.zeros(1)

            target_vars['x_mod'] = x_mod
            target_vars['x_off'] = x_off
            target_vars['temp'] = temp
            target_vars['energy_neg'] = energy_neg
            target_vars['test_x_mod'] = test_x_mod
            target_vars['eps_begin'] = eps_begin
            target_vars['ATTENTION_MASK'] = ATTENTION_MASK
            target_vars['models_pretrain'] = models_pretrain
            if FLAGS.comb_mask:
                target_vars['attention_mask'] = tf.nn.softmax(attention_mask)
            else:
                target_vars['attention_mask'] = tf.zeros(1)

        if FLAGS.train:
            # Average gradients across towers and build the train op.
            grads = average_gradients(tower_grads)
            train_op = optimizer.apply_gradients(grads)
            target_vars['train_op'] = train_op

    # sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(
        max_to_keep=30, keep_checkpoint_every_n_hours=6)

    # Report total trainable-parameter count.
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    # Restore a checkpoint when resuming or when only testing.
    if (FLAGS.resume_iter != -1 or not FLAGS.train):
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    print("Initializing variables...")

    print("Start broadcast")
    print("End broadcast")

    if FLAGS.train:
        train(target_vars, saver, sess,
              logger, data_loader, resume_itr,
              logdir)

    test(target_vars, saver, sess, logger, data_loader)
def train(local_save_dir):
    """Run multi-GPU training, logging losses and periodically saving/evaluating.

    Splits each batch from the data loader across FLAGS.gpu_num towers,
    averages tower gradients, and drives the epoch/batch loop.  Checkpoints
    go to ``local_save_dir/checkpoint`` and logs to ``local_save_dir/log``.

    Args:
        local_save_dir: directory under which the log/ and checkpoint/
            subdirectories are created.

    NOTE(review): depends on module-level FLAGS, datasets, utils,
    get_model_fn, Logger, and evaluation defined elsewhere in the file.
    """
    # ``log`` first names the log DIRECTORY; it is later reused for the
    # per-batch log STRING inside the loop.
    log = os.path.join(local_save_dir, 'log')
    if not os.path.exists(log):
        os.makedirs(log)
    logger = Logger(log + "/log", level=FLAGS.logger_level)

    with tf.device('cpu:0'):
        data_iterator, data_init_op, num_batch = datasets.data_loader(
            is_training=True)
        data = data_iterator.get_next()

        # Shard every tensor in the batch dict across the GPUs (axis 0).
        data_split = [{} for _ in range(FLAGS.gpu_num)]
        for k, t in data.items():
            t_split = tf.split(t, FLAGS.gpu_num, axis=0)
            for i, t_small in enumerate(t_split):
                data_split[i][k] = t_small

        optimizer = tf.train.MomentumOptimizer(FLAGS.base_lr, 0.9)

        # Build one model replica per GPU; collect per-tower gradients.
        grads = []
        display_losses = []
        for i in range(FLAGS.gpu_num):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%d' % i):
                    model_fn = get_model_fn()
                    model = model_fn(data_split[i],
                                     is_training=True)

                    grads_sub = []
                    for d in model.compute_gradients_losses:
                        grads_sub += optimizer.compute_gradients(
                            loss=d['value'], var_list=d['var_list'])
                    grads.append(grads_sub)

                display_losses += model.display_losses

        grads = utils.average_gradients(grads)

        # Run UPDATE_OPS (e.g. batch-norm moving averages) with each step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.apply_gradients(grads)

        var_init_op = tf.group(tf.local_variables_initializer(),
                               tf.global_variables_initializer())

        saver = tf.train.Saver(max_to_keep=5)

        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True))
        sess.run(var_init_op)
        sess.run(data_init_op)

        print(tf.trainable_variables())

        # Resume from checkpoint if one exists; global_step encodes the
        # iteration count, from which the starting epoch is derived.
        ckpt = os.path.join(local_save_dir, 'checkpoint')
        load_model_status, global_step = utils.load(sess, saver, ckpt)
        if load_model_status:
            iter_num = global_step
            start_epoch = global_step // num_batch
            print("[*] Model restore success!")
        else:
            iter_num = 0
            start_epoch = 0
            print("[*] Not find pretrained model!")

        start = time.time()
        for epoch_id in range(start_epoch, FLAGS.epochs):
            for batch_id in range(num_batch):
                _, losses_eval = sess.run([train_op, display_losses])

                end = time.time()

                # Group loss values by name so duplicates across towers
                # can be averaged for display.
                losses_dict = {}
                for d in losses_eval:
                    if d['name'] in losses_dict.keys():
                        losses_dict[d['name']] += [d['value']]
                    else:
                        losses_dict[d['name']] = [d['value']]

                log = "Epoch: [%2d] [%4d/%4d] time: %s | " % (
                    epoch_id+1, batch_id+1, num_batch,
                    str(timedelta(seconds=end-start))[0:10])
                for k, v in losses_dict.items():
                    # Loss names come back as bytes from sess.run — decode
                    # for display.  TODO confirm all names are bytes.
                    k = k.decode("utf-8")
                    log += "%s: %.6f " % (k, np.mean(v))
                logger.logger.info(log)
                iter_num += 1

            # Re-log the last batch line as the epoch summary.
            logger.logger.info(log)

            if np.mod(epoch_id + 1, FLAGS.save_every_epoch) == 0:
                utils.save(sess, saver, iter_num, ckpt)
            if np.mod(epoch_id + 1, FLAGS.eval_every_epoch) == 0:
                evaluation(local_save_dir, sess, logger)

        print("[*] Finish training.")