Exemplo n.º 1
0
def train(epoch, model, train_loader, test_loader, optimizer, device, f):
    """Run one VAE training epoch, then log eval/test metrics and save
    reconstruction image grids.

    Args:
        epoch: 1-based epoch index (the CSV header row is written when
            epoch == 1).
        model: VAE-style model returning (px_logit, variational_params,
            latent_samples) for a (N, 784) input batch.
        train_loader, test_loader: MNIST-style loaders; images are
            flattened to 784 features.
        optimizer: optimizer stepping the model parameters.
        device: torch device tensors are moved to.
        f: open file/stream handle passed to stream_print for CSV logging.
    """
    model.train()
    train_loss = 0
    for batch_idx, (data, targets) in enumerate(train_loader):
        data = data.to(device).view(-1, 784)
        optimizer.zero_grad()
        px_logit, variational_params, latent_samples = model(data)
        loss = loss_function(data, targets, px_logit, variational_params,
                             latent_samples)
        # The dict entry 'optimization_loss' is the differentiable objective.
        loss['optimization_loss'].backward()
        train_loss += loss['optimization_loss'].item()
        optimizer.step()

    model.eval()
    with torch.no_grad():

        # Evaluate on a random 10k subset of the raw training images
        # (np.random.choice samples with replacement; raw uint8 pixels are
        # rescaled to [0, 1]).
        data_eval = train_loader.dataset.data.view(-1, 784)[np.random.choice(
            50000, 10000)].to(device) / 255.0
        recon_batch_eval, var_eval, lat_eval = model(data_eval)
        # NOTE(review): `targets` here is the leftover value from the last
        # training batch, so its size may not match data_eval/data_test —
        # confirm loss_function ignores it or tolerates the mismatch.
        loss_eval = loss_function(data_eval, targets, recon_batch_eval,
                                  var_eval, lat_eval)

        data_test = test_loader.dataset.data.view(-1, 784).to(device) / 255.0
        recon_batch_test, var_test, lat_test = model(data_test)
        loss_test = loss_function(data_test, targets, recon_batch_test,
                                  var_test, lat_test)

        # Per-sample entropy and loss on the eval subset (a, b) and the
        # test set (c, d); e is the test accuracy.
        a, b, c, d = -loss_eval['nent'] / len(data_eval), loss_eval[
            'optimization_loss'] / len(data_eval), -loss_test['nent'] / len(
                data_test), loss_test['optimization_loss'] / len(data_test)
        e = test_acc(model, test_loader, device)
        string = ('{:>10s},{:>10s},{:>10s},{:>10s},{:>10s},{:>10s}'.format(
            'tr_ent', 'tr_loss', 't_ent', 't_loss', 't_acc', 'epoch'))
        stream_print(f, string, epoch == 1)
        string = ('{:10.2e},{:10.2e},{:10.2e},{:10.2e},{:10.2e},{:10d}'.format(
            a, b, c, d, e, epoch))
        stream_print(f, string)

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))

    # Save side-by-side (input, reconstruction) grids using the LAST
    # training batch (data/px_logit leak out of the loop above);
    # px_logit is indexed per sample/component i.
    n = min(data.size(0), 8)
    for i in range(10):
        comparison = torch.cat([
            data.view(-1, 1, 28, 28)[:n], px_logit[i][:n].view(-1, 1, 28, 28)
        ])
        save_image(comparison.cpu(),
                   'logs/reconstruction_' + str(epoch) + '_' + str(i) + '.png',
                   nrow=n)
Exemplo n.º 2
0
def eval(model, valid_loader, best_loss, count):
    """Evaluate `model` on the validation loader and checkpoint on improvement.

    Args:
        model: recommender model called as model(user, item).
        valid_loader: validation DataLoader whose dataset supports
            ng_sample() negative sampling.
        best_loss: best validation loss observed so far.
        count: global step, passed to drop_rate_schedule.

    Returns:
        The (possibly updated) best validation loss.
    """
    model.eval()
    epoch_loss = 0
    valid_loader.dataset.ng_sample()  # negative sampling
    # Inference only: disable autograd so no graph is built for eval batches.
    with torch.no_grad():
        for user, item, label, noisy_or_not in valid_loader:
            user = user.cuda()
            item = item.cuda()
            label = label.float().cuda()

            prediction = model(user, item)
            loss = loss_function(prediction, label, drop_rate_schedule(count))
            epoch_loss += loss.detach()
    print("################### EVAL ######################")
    print("Eval loss:{}".format(epoch_loss))
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        if args.out:
            # makedirs(exist_ok=True) also creates missing parents and avoids
            # the check-then-create race that os.path.exists + mkdir had.
            os.makedirs(model_path, exist_ok=True)
            torch.save(
                model,
                '{}{}_{}-{}.pth'.format(model_path, args.model, args.drop_rate,
                                        args.num_gradual))
    return best_loss
    def train_step(enc_input, dec_target):
        """One seq2seq training step: encode, decode with the target sequence,
        and apply the gradient update.

        Returns the batch loss tensor.
        """
        with tf.GradientTape() as tape:
            # enc_output (batch_size, enc_len, enc_unit)
            # enc_hidden (batch_size, enc_unit)
            enc_output, enc_hidden = model.encoder(enc_input)

            # First decoder input: the start-of-sequence token.
            # dec_input (batch_size, 1)
            dec_input = tf.expand_dims([start_index] * batch_size, 1)

            # Initial decoder hidden state comes from the encoder.
            # dec_hidden (batch_size, enc_unit)
            dec_hidden = enc_hidden
            # Predict the output sequence token by token.
            # predictions (batch_size, dec_len-1, vocab_size)
            predictions, _ = model(dec_input, dec_hidden, enc_output,
                                   dec_target)

            # Targets are shifted by one (the <start> token is skipped);
            # padding positions are masked out via the pad token index.
            _batch_loss = loss_function(dec_target[:, 1:], predictions,
                                        vocab.pad_token_idx)

        variables = model.trainable_variables
        gradients = tape.gradient(_batch_loss, variables)
        optimizer.apply_gradients(zip(gradients, variables))

        return _batch_loss
Exemplo n.º 4
0
def train_step(img_tensor, target, tokenizer, loss_object, validation=False):
    """Run one captioning step with teacher forcing and, unless validating,
    apply the gradient update.

    Args:
        img_tensor: CNN feature output for the batch of images.
        target: caption token matrix of shape (batch, max_length).
        tokenizer: tokenizer providing the '<start>' token index.
        loss_object: loss callable handed through to loss_function.
        validation: when True, skip the optimizer step and additionally
            return the final-step predictions.

    Returns:
        (loss, total_loss) during training, or
        (loss, total_loss, predictions) during validation.
    """
    batch_size = target.shape[0]
    seq_len = target.shape[1]

    loss = 0
    hidden = decoder.reset_state(batch_size=batch_size)
    start_token = tokenizer.word_index['<start>']
    dec_input = tf.expand_dims([start_token] * batch_size, 1)

    with tf.GradientTape() as tape:
        features = encoder(img_tensor)
        for step in range(1, seq_len):
            predictions, hidden, _ = decoder(dec_input, features, hidden)
            loss += loss_function(target[:, step], predictions, loss_object)
            # Teacher forcing: feed the ground-truth token, not the prediction.
            dec_input = tf.expand_dims(target[:, step], 1)

    total_loss = loss / int(seq_len)

    if validation:
        return loss, total_loss, predictions

    trainable_variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables))
    return loss, total_loss
Exemplo n.º 5
0
def train_step(img_tensor, target):
    """One teacher-forced captioning training step with gradient update.

    Returns:
        (loss, total_loss): the summed per-step loss and the per-timestep
        average used for reporting.
    """
    bs = target.shape[0]
    max_len = target.shape[1]
    loss = 0

    # Reset the decoder state per batch: captions are unrelated across images.
    hidden = decoder.reset_state(batch_size=bs)  # shape: (bs, 512)
    dec_input = tf.expand_dims(
        [tokenizer.word_index['<start>']] * bs, 1)  # shape: (bs,1)

    with tf.GradientTape() as tape:
        features = encoder(img_tensor)  # features shape: (bs, 64, 256)

        for t in range(1, max_len):
            predictions, hidden, _ = decoder(dec_input, features, hidden)
            loss += loss_function(target[:, t], predictions)
            # Teacher forcing: the next input is the ground-truth token.
            dec_input = tf.expand_dims(target[:, t], 1)

    total_loss = loss / int(max_len)

    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return loss, total_loss
Exemplo n.º 6
0
def train_loop(model, dataloader, optimizer, device, dataset_len, scale):
    """Run one training epoch.

    Returns:
        (epoch_loss, epoch_acc): mean batch loss and accuracy over
        `dataset_len` samples.
    """
    model.train()

    total_loss = 0.0
    correct = 0

    for batch in dataloader:
        optimizer.zero_grad()

        labels, post_features, lens = batch
        labels = labels.to(device)
        post_features = post_features.to(device)

        output_normal, output_adv = model(post_features, lens, labels)
        preds = torch.max(output_normal, 1)[1]

        loss = loss_function(output_normal, output_adv, labels, scale)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct += torch.sum(preds == labels.data)

    epoch_loss = total_loss / len(dataloader)
    epoch_acc = correct.double() / dataset_len

    return epoch_loss, epoch_acc
Exemplo n.º 7
0
def valid_step(inp, tar):
    """Evaluate one batch: update the running validation loss and the
    padding-masked accuracy metric (positions where tar == 0 are padding)."""
    mask = create_masks(inp)
    predictions = lang_model.model(inp, mask, False)
    loss = loss_function(tar, predictions)

    # Weight out padding positions in the accuracy metric.
    non_pad = tf.cast(tf.logical_not(tf.math.equal(tar, 0)), tf.int32)
    valid_loss(loss)
    valid_acc(tar, predictions, sample_weight=non_pad)
Exemplo n.º 8
0
def train_step(inp, tar):
    """One language-model training step: forward pass under the tape,
    gradient update, then padding-masked metric updates."""
    mask = create_masks(inp)
    with tf.GradientTape() as tape:
        predictions = lang_model.model(inp, mask, True)
        loss = loss_function(tar, predictions)
    variables = lang_model.model.trainable_variables
    gradients = tape.gradient(loss, variables)
    opt.apply_gradients(zip(gradients, variables))

    # Weight out padding positions (tar == 0) in the accuracy metric.
    non_pad = tf.cast(tf.logical_not(tf.math.equal(tar, 0)), tf.int32)
    train_loss(loss)
    train_acc(tar, predictions, sample_weight=non_pad)
Exemplo n.º 9
0
def model_test(epoch):
    """Compute the mean VAE test-set loss and save one grid comparing
    inputs with their reconstructions for the first batch."""
    model.eval()
    total = 0
    with torch.no_grad():
        for idx, batch in enumerate(test):
            batch = batch.to(dev)
            recon_batch, mu, logvar = model(batch)
            total += loss_function(recon_batch, batch, mu, logvar).item()
            if idx == 0:
                n = min(batch.size(0), 8)
                grid = torch.cat(
                    [batch[:n],
                     recon_batch.view(batch_size, 1, 20, 20)[:n]])
                save_image(grid.cpu(),
                           'results/reconstruction_' + str(epoch) + '.png',
                           nrow=n)
    total /= len(test)
    print('====> Test set loss: {:.4f}'.format(total))
Exemplo n.º 10
0
def train(epoch):
    """Run one epoch of style-transfer training over `data_loader`.

    Fixes vs. the previous version:
      * epoch_loss is actually accumulated — it was never updated, so the
        epoch-average print always showed 0;
      * the epoch average divides by len(data_loader) — the loader that was
        iterated — instead of the unrelated `training_data_loader`;
      * deprecated `Variable(..., volatile=True)` / `loss.data[0]` are
        replaced with `detach()` / `loss.item()`.
    """
    epoch_loss = 0
    for iteration, batch in enumerate(data_loader):
        x = batch[0]
        x = batch_rgb_to_bgr(x)
        if cuda:
            x = x.cuda()

        y_hat = model(x)
        # Content target: same pixels, detached so no gradient flows into it.
        xc = x.detach()
        optimizer.zero_grad()
        loss = loss_function(args.content_weight, args.style_weight, xc, xs, y_hat, cuda)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss

        print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, len(data_loader), batch_loss))

    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, epoch_loss / len(data_loader)))
Exemplo n.º 11
0
def model_train(epoch):
    """Train the VAE for one epoch over `train`, logging every
    `log_interval` batches and printing the epoch average at the end."""
    model.train()
    running = 0
    for batch_idx, data in enumerate(train):
        data = data.to(dev)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        running += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            done = batch_idx * len(data)
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tloss: {:.0f}'.format(
                epoch, done, len(train),
                100. * batch_idx / len(train),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, running / len(train)))
Exemplo n.º 12
0
def train(args, model, optimizer, train_loader, device, epoch):
    """Run one VAE training epoch; logs every args.log_interval batches
    and prints the per-sample epoch average at the end."""
    model.train()
    running = 0
    for idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        running += loss.item()
        optimizer.step()
        if idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, idx * len(data), len(train_loader.dataset),
                100. * idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, running / len(train_loader.dataset)))
Exemplo n.º 13
0
def eval_loop(model,
              expt_type,
              dataloader,
              device,
              dataset_len,
              loss_type,
              scale=1):
    """Evaluate `model` over `dataloader` without gradient tracking.

    Returns:
        (epoch_loss, epoch_accuracy, predictions, targets, confidences):
        mean batch loss, accuracy over `dataset_len` samples, flat numpy
        arrays of predictions and targets, and a list of per-batch raw
        output arrays.
    """
    model.eval()

    total_loss = 0.0
    correct = 0

    all_targets = []
    all_preds = []
    all_conf = []

    for batch in dataloader:
        labels, tweet_features, lens = batch
        labels = labels.to(device)
        tweet_features = tweet_features.to(device)

        # Forward pass only — no gradients needed during evaluation.
        with torch.no_grad():
            output = model(tweet_features, lens, labels)

        preds = torch.max(output, 1)[1]
        loss = loss_function(output, labels, loss_type, expt_type, scale)

        total_loss += loss.item()
        correct += torch.sum(preds == labels.data)

        all_conf.append(output.cpu().detach().numpy())
        all_targets.append(labels.cpu().detach().numpy())
        all_preds.append(preds.cpu().detach().numpy())

    epoch_loss = total_loss / len(dataloader)
    epoch_accuracy = correct.double() / dataset_len

    return epoch_loss, epoch_accuracy, np.hstack(all_preds), np.hstack(
        all_targets), all_conf
Exemplo n.º 14
0
def generator_loss():
    """Flask view: configure the generator loss function from a POST body.

    Expects JSON {'loss': <loss name, e.g. 'BCE'/'MSE'>, 'state': <label
    used in the confirmation message>}. On success stores the resolved
    loss callable in necessary_elements and returns a confirmation
    payload; on an unknown loss name returns an error payload.
    """
    if request.method=='POST':
        req_data = request.get_json()
        in_generator_loss = req_data['loss']
        loss_fn_state = req_data['state']
        generator_loss_function = loss_function(in_generator_loss)
        # `is not None` — identity is the correct way to test for None.
        if generator_loss_function is not None:
            necessary_elements['generator_loss_function'] = generator_loss_function
            # NOTE(review): 'mimetype' here becomes a field of the JSON body,
            # not the response mimetype (jsonify already sets that) — kept
            # for client compatibility.
            return jsonify(
                response= 'generator loss function ' + str(loss_fn_state) + ' successfully',
                loss_function= str(generator_loss_function),
                mimetype='application/json'
            )
        else:
            # NOTE(review): 'status' is only a JSON field; the HTTP status
            # is still 200. Return (jsonify(...), 500) to actually set it.
            return jsonify(
                response= 'request failed',
                status=500,
                hint= 'try BCE or MSE or refer to documentation'
                )
Exemplo n.º 15
0
def discriminator_loss():
    """Flask view: configure the discriminator loss function from a POST body.

    Expects JSON {'loss': <loss name, e.g. 'BCE'/'MSE'>, 'state': <label
    used in the confirmation message>}. On success stores the resolved
    loss callable in necessary_elements and returns a confirmation
    payload; on an unknown loss name returns an error payload.
    """
    if request.method=='POST':
        req_data = request.get_json()
        #loss function:
        in_discriminator_loss = req_data['loss']
        #whether the loss is used for the first time or updated:
        loss_fn_state = req_data['state']
        discriminator_loss_function = loss_function(in_discriminator_loss)
        # `is not None` — identity is the correct way to test for None.
        if discriminator_loss_function is not None:
            necessary_elements['discriminator_loss_function'] = discriminator_loss_function
            # NOTE(review): 'mimetype' here becomes a field of the JSON body,
            # not the response mimetype (jsonify already sets that) — kept
            # for client compatibility.
            return jsonify(
                response= 'discriminator loss function ' + str(loss_fn_state) + ' successfully',
                loss_function= str(discriminator_loss_function),
                mimetype='application/json'
            )
        else:
            # NOTE(review): 'status' is only a JSON field; the HTTP status
            # is still 200. Return (jsonify(...), 500) to actually set it.
            return jsonify(
                response= 'request failed',
                status=500,
                hint= 'try BCE or MSE or refer to documentation'
                )
def validation_step(img_tensor, target):
    """Teacher-forced validation pass: accumulate the caption loss without
    a gradient tape or optimizer update.

    Returns:
        (loss, total_loss): summed per-step loss and per-timestep average.
    """
    bs = target.shape[0]
    max_len = target.shape[1]
    loss = 0

    # Fresh decoder state per batch: captions are unrelated across images.
    hidden = decoder.reset_state(batch_size=bs)  # shape: (bs, 512)
    dec_input = tf.expand_dims(
        [tokenizer.word_index['<start>']] * bs, 1)  # shape: (bs,1)

    features = encoder(img_tensor)  # features shape: (bs, 64, 256)

    for t in range(1, max_len):
        predictions, hidden, _ = decoder(dec_input, features, hidden)
        loss += loss_function(target[:, t], predictions)
        # Teacher forcing: feed the ground-truth token.
        dec_input = tf.expand_dims(target[:, t], 1)

    total_loss = loss / int(max_len)

    return loss, total_loss
Exemplo n.º 17
0
    def _forward_pass(self, img_tensor, target):
        """Decode a caption greedily and accumulate the loss against `target`.

        Unlike a teacher-forced step, the next decoder input is the model's
        own argmax prediction, so this also collects the generated token ids.

        Args:
            img_tensor -- this is output of CNN
            target -- caption vectors of dim (units, max_length) where units is num GRUs and max_length is size of caption with most tokens

        Returns:
            (loss, total_loss, result_ids): summed per-step loss, the
            per-timestep average, and the list of predicted id tensors.
        """
        loss = 0

        hidden = self.decoder.reset_state(batch_size=target.shape[0])
        dec_input = tf.expand_dims([self.tokenizer.word_index['<start>']] * target.shape[0], 1)

        features = self.encoder(img_tensor)
        result_ids = []
        for i in range(1, target.shape[1]):
            predictions, hidden, _ = self.decoder(dec_input, features, hidden)
            predicted_ids = tf.math.argmax(predictions, axis=1)
            result_ids.append(predicted_ids)
            loss += loss_function(target[:, i], predictions, self.loss_object)
            # Feed the model's own prediction back in (greedy decoding),
            # NOT the ground-truth token — this is not teacher forcing.
            dec_input = tf.expand_dims(predicted_ids, 1)

        total_loss = (loss / int(target.shape[1]))

        return loss, total_loss, result_ids
Exemplo n.º 18
0
def train():
    """Train the retina detection network (TF1 graph mode).

    Builds the queue-based input pipeline, placeholders, network, loss and
    metric ops, then runs FLAGS.max_steps optimization steps, fetching
    TensorBoard summaries every `tb_step` steps and saving a checkpoint
    once per epoch and once after the loop finishes.
    """
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()
        tb_step = 50  # interval (in steps) between summary fetches
        # One checkpoint per pass over the training set.
        save_step = math.ceil(FLAGS.num_examples_per_epoch_train /
                              FLAGS.batch_size)

        var = FLAGS.var

        images_train, labels_train, labels_objects = distorted_inputs_files(
            filename=FLAGS.training_file,
            data_dir="../data/retina/",
            batch_size=FLAGS.batch_size,
            var=var,
            half=FLAGS.training_half)
        print("Image_train shape: " + str(images_train.get_shape()))

        images_placeholder = tf.placeholder(
            tf.float32,
            shape=(FLAGS.batch_size,
                   int(FLAGS.image_height * FLAGS.resize_factor),
                   int(FLAGS.image_width * FLAGS.resize_factor),
                   FLAGS.num_channels))

        labels_placeholder = tf.placeholder(
            tf.float32,
            shape=(FLAGS.batch_size,
                   int(FLAGS.image_height * FLAGS.resize_factor),
                   int(FLAGS.image_width * FLAGS.resize_factor),
                   FLAGS.num_classes * FLAGS.num_objects))

        labels_objects_placeholder = tf.placeholder(tf.float32,
                                                    shape=(FLAGS.batch_size,
                                                           FLAGS.num_objects))

        bn_training_placeholder = tf.placeholder(tf.bool)
        keep_prob = tf.placeholder(tf.float32)  # dropout (keep probability)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits_map, logits_class = create_conv_net(
            images_placeholder,
            keep_prob,
            FLAGS.num_channels,
            FLAGS.num_classes * FLAGS.num_objects,
            summaries=False,
            two_sublayers=False,
            training=bn_training_placeholder,
            dropout=0.5)

        loss_ = loss_function(logits_map, logits_class, labels_placeholder,
                              labels_objects_placeholder)
        evaluation_ = evaluation(logits_map, labels_placeholder, True)
        precision_ = precision(logits_class, labels_objects_placeholder)

        train_op = train_model(loss_, global_step)

        saver = tf.train.Saver(max_to_keep=0)  # keep all checkpoints

        merged = tf.summary.merge_all()

        with tf.Session() as sess:
            print("Starting Training Loop")
            sess.run(tf.local_variables_initializer())
            print("Local Variable Initializer done...")
            sess.run(tf.global_variables_initializer())
            print("Global Variable Initializer done...")
            coord = tf.train.Coordinator()
            print("Train Coordinator done...")
            print("Starting Queue Runner")
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            for step in range(FLAGS.max_steps):
                start_time = time.time()
                # Pull one batch out of the queue-based pipeline, then feed
                # it back in through the placeholders.
                images_s, labels_s, labels_objects_s = sess.run(
                    [images_train, labels_train, labels_objects])

                train_feed = {
                    images_placeholder: images_s,
                    labels_placeholder: labels_s,
                    labels_objects_placeholder: labels_objects_s,
                    keep_prob: 0.5,
                    bn_training_placeholder: True
                }

                if step % tb_step == 0:
                    # Also fetch merged summaries on TensorBoard steps.
                    summary, _, precision_value, evaluation_value, loss_value = sess.run(
                        [merged, train_op, precision_, evaluation_, loss_],
                        feed_dict=train_feed)

                else:
                    _, precision_value, evaluation_value, loss_value = sess.run(
                        [train_op, precision_, evaluation_, loss_],
                        feed_dict=train_feed)

                duration = time.time() - start_time

                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = (
                    '%s: step %d, loss = %.6f, precision=%.6f, joint_accuarcy=%.6f (%.1f examples/sec; %.3f '
                    'sec/batch)')
                print_str_loss = format_str % (
                    datetime.now(), step, loss_value, precision_value,
                    np.mean(evaluation_value), examples_per_sec, sec_per_batch)
                print(print_str_loss)

                if step % save_step == 0:
                    save_path = saver.save(sess,
                                           FLAGS.train_dir + FLAGS.experiment +
                                           "/" + 'model.cpkt-' + str(step),
                                           write_meta_graph=False)
                    print('Model saved in file: ' + save_path)

            # Final checkpoint after the loop completes.
            save_path = saver.save(sess,
                                   FLAGS.train_dir + FLAGS.experiment + "/" +
                                   'model.cpkt-' + str(step),
                                   write_meta_graph=False)

            print('Model saved in file: ' + save_path)
            coord.request_stop()
            coord.join(threads)
Exemplo n.º 19
0
    model = GCNModelVAE(128, 32, 64).cuda()
    optim = Adam(model.parameters(), lr=lr)

    for i in range(epoch):
        # train
        model.train()
        optim.zero_grad()
        prediction, mu, logvar = model(
            torch.Tensor(drugs).cuda(),
            torch.LongTensor(proteins).cuda(), adj_train.cuda())
        # print("affinity",affinity[train_rows, train_cols].shape)
        # print(torch.FloatTensor(affinity[train_rows, train_cols]).size())
        # print("prediction",prediction[train_rows, train_cols].size())
        # print(prediction[train_rows, train_cols])
        loss = loss_function(
            torch.FloatTensor(affinity[train_rows, train_cols]).cuda(),
            prediction[train_rows, train_cols], mu, logvar,
            len(drugs) + len(proteins))
        # print(float(loss))
        loss.backward()
        optim.step()
        if i % 10 == 0:
            print("epoch:", i, "===========")
            print("training loss:", float(loss))
            model.eval()
            with torch.no_grad():
                mse_val = mse(
                    affinity[valid_rows, valid_cols],
                    prediction[valid_rows, valid_cols].cpu().numpy().flatten())
                print("validation mse:", mse_val)

                mse_test = mse(
Exemplo n.º 20
0
        # prepare input for fusion
        input_to_fusion = torch.cat(
            [source, image_flow, image_gen,
             flow.permute(0, 3, 1, 2)], dim=1)

        # get fusion scores
        fusion_scores = torch.sigmoid(fusion(input_to_fusion))

        # final GAF synthesis
        image_gaf = fusion_scores * image_flow + (1 -
                                                  fusion_scores) * image_gen

        # loss
        loss_net1, discrim_loss = loss_function(image_gaf, target, discrim,
                                                gan_loss, feature_loss,
                                                l1_loss)

        # discriminator optimization
        discrim_loss.mean().backward()
        optim_d.step()

        net_loss = loss_net1
        net_loss.mean().backward()

        # main network optimization
        optim.step()

        loss_train += net_loss.mean().item()
        loss_train_d += discrim_loss.mean().item()
Exemplo n.º 21
0
def main(args):
    tdatetime = dt.now()
    train_date = tdatetime.strftime('%Y%m%d')
    train_log_file = open(
        os.path.join(args.save_dir, 'train_{}.txt'.format(train_date)), 'w')
    val_log_file = open(
        os.path.join(args.save_dir, 'val_{}.txt'.format(train_date)), 'w')

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    with torch.cuda.device(args.gpu_device_num):
        CASENet = Model(ResidualBlock, [3, 4, 23, 3], class_num)
        CASENet.cuda()

        train_loader = get_loader(img_root=args.train_image_dir,
                                  mask_root=args.train_mask_dir,
                                  json_path=args.train_json_path,
                                  pair_transform=pair_transform,
                                  input_transform=input_transform,
                                  target_transform=None,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)

        val_loader = get_loader(img_root=args.val_image_dir,
                                mask_root=args.val_mask_dir,
                                json_path=args.val_json_path,
                                pair_transform=val_pair_transform,
                                input_transform=input_transform,
                                target_transform=None,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=args.num_workers)

        lr = args.learning_rate
        optimizer = torch.optim.SGD(CASENet.parameters(),
                                    lr=lr,
                                    momentum=0.9,
                                    weight_decay=0.0005)
        loss_latest = 0
        batch_batch_count = 0

        # Training
        for epoch in tqdm(range(args.epochs)):
            if args.batch_batch:
                """
					using mini-batch in mini-batch
				"""
                train_loss_total = 0
                train_prog = tqdm(enumerate(train_loader),
                                  total=len(train_loader))
                for i, (images, masks) in train_prog:

                    images = Variable(images).cuda()
                    masks = Variable(masks).cuda()

                    optimizer.zero_grad()

                    fused_output, side_output = CASENet(images)

                    # actually, in the edge detection, we need set the weight, witch is none edge pix rate.
                    loss_side = loss_function(side_output, masks)
                    loss_fuse = loss_function(fused_output, masks)
                    loss = loss_side + loss_fuse
                    if batch_batch_count < args.batch_batch_size:
                        batch_batch_count += 1
                        continue
                    else:
                        batch_batch_count = 0

                    train_loss_total += loss.data[0]
                    loss.data[0] /= args.batch_batch_size
                    loss.backward()
                    optimizer.step()

                    train_prog.set_description("batch loss : {:.5}".format(
                        loss.data[0]))

                torch.save(
                    CASENet.state_dict(),
                    args.save_dir + 'CASENet_param_{}.pkl'.format(epoch))

            else:
                """
					usual training
				"""
                train_loss_total = 0
                train_prog = tqdm(enumerate(train_loader),
                                  total=len(train_loader))
                for i, (images, masks) in train_prog:

                    images = Variable(images).cuda()
                    masks = Variable(masks).cuda()

                    optimizer.zero_grad()

                    fused_output, side_output = CASENet(images)

                    # actually, in the edge detection, we need set the weight, witch is none edge pix rate.
                    loss_side = loss_function(side_output, masks)
                    loss_fuse = loss_function(fused_output, masks)
                    loss = loss_side + loss_fuse
                    train_loss_total += loss.data[0]
                    loss.backward()
                    optimizer.step()

                    train_prog.set_description("batch loss : {:.5}".format(
                        loss.data[0]))

                torch.save(
                    CASENet.state_dict(),
                    args.save_dir + 'CASENet_param_{}.pkl'.format(epoch))

            # Decaying Learning Rate
            if (epoch + 1) % 30 == 0:
                lr /= 10
                optimizer = torch.optim.SGD(CASENet.parameters(),
                                            lr=lr,
                                            momentum=0.9)

            #print("train loss [epochs {0}/{1}]: {2}".format( epoch, args.epochs,train_loss_total))
            train_log_file.write("{}".format(train_loss_total))
            train_log_file.flush()

            val_prog = tqdm(enumerate(val_loader), total=len(val_loader))
            CASENet.eval()
            val_loss_total = 0

            for i, (images, masks) in val_prog:
                images = Variable(images).cuda()
                masks = Variable(masks).cuda()

                fused_output, side_output = CASENet(images)

                # actually, in the edge detection, we need set the weight, witch is none edge pix rate.
                loss_side = loss_function(side_output, masks)
                loss_fuse = loss_function(fused_output, masks)
                loss = loss_side + loss_fuse
                val_loss_total += loss.data[0]

                val_prog.set_description(
                    "validation batch loss : {:.5}".format(loss.data[0]))
                if i == 0:
                    predic = F.log_softmax(fused_output)
                    predic = predic[0]
                    _, ind = predic.sort(1)
                    ind = ind.cpu().data.numpy()
                    msk = masks.cpu().data.numpy()
                    ind = Image.fromarray(np.uint8(ind[-1]))
                    msk = Image.fromarray(np.uint8(msk[0]))
                    ind.save(args.save_dir +
                             "output_epoch{}.png".format(epoch))
                    msk.save(args.save_dir + "mask_epoch{}.png".format(epoch))

            #print("validation loss : {0}".format(val_loss_total))
            val_log_file.write("{}".format(val_loss_total))
            val_log_file.flush()
            CASENet.train()

        # Save the Model
        torch.save(CASENet.state_dict(),
                   'CASENet_{0}_fin.pkl'.format(args.epochs))
        log_file.close()
Exemplo n.º 22
0
# Best validation loss seen so far; eval() lowers it when it improves.
best_loss = 1e9

for epoch in range(args.epochs):
	model.train() # Enable dropout (if any).

	start_time = time.time()
	# Draw fresh negative samples for this epoch.
	train_loader.dataset.ng_sample()

	for user, item, label, noisy_or_not in train_loader:
		user = user.cuda()
		item = item.cuda()
		label = label.float().cuda()

		model.zero_grad()
		prediction = model(user, item)
		loss = loss_function(prediction, label, args.alpha)
		loss.backward()
		optimizer.step()

		# Periodically evaluate; eval() checkpoints when the loss improves.
		# NOTE(review): `count` is assumed to be initialized earlier in the
		# file (it is only incremented here) — confirm.
		if count % args.eval_freq == 0 and count != 0:
			print("epoch: {}, iter: {}, loss:{}".format(epoch, count, loss))
			best_loss = eval(model, valid_loader, best_loss, count)
			model.train()

		count += 1

print("############################## Training End. ##############################")
# Reload the best checkpoint and run the final test evaluation.
test_model = torch.load('{}{}_{}.pth'.format(model_path, args.model, args.alpha))
test_model.cuda()
test(test_model, test_data_pos, user_pos)
Exemplo n.º 23
0
def main():
    """Train a dual-encoder extreme-multilabel (XMC) model with HF Accelerate.

    Parses CLI args, builds label/instance dataloaders for the selected mode
    ('ict', 'self-train', 'finetune-pair', 'finetune-label'), then runs
    mixed-precision contrastive training with periodic clustering-based
    evaluation and checkpointing of both encoder halves.
    """
    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    args = parse_args()
    distributed_args = accelerate.DistributedDataParallelKwargs(
        find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[distributed_args])
    device = accelerator.device
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        filename=f'xmc_{args.dataset}_{args.mode}_{args.log}.log',
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)

    # Setup logging, we only want one process per machine to log things on the screen.
    # accelerator.is_local_main_process is only True for one process per machine.
    logger.setLevel(
        logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    ch = logging.StreamHandler(sys.stdout)
    logger.addHandler(ch)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    logger.info(sent_trans.__file__)

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Load pretrained model and tokenizer
    if args.model_name_or_path == 'bert-base-uncased' or args.model_name_or_path == 'sentence-transformers/paraphrase-mpnet-base-v2':
        query_encoder = build_encoder(
            args.model_name_or_path,
            args.max_label_length,
            args.pooling_mode,
            args.proj_emb_dim,
        )
    else:
        query_encoder = sent_trans.SentenceTransformer(args.model_name_or_path)

    tokenizer = query_encoder._first_module().tokenizer

    # Label and instance sides share one encoder (siamese setup).
    block_encoder = query_encoder

    model = DualEncoderModel(query_encoder, block_encoder, args.mode)
    model = model.to(device)

    # the whole label set
    data_path = os.path.join(os.path.abspath(os.getcwd()), 'dataset',
                             args.dataset)
    all_labels = pd.read_json(os.path.join(data_path, 'lbl.json'), lines=True)
    label_list = list(all_labels.title)
    label_ids = list(all_labels.uid)
    label_data = SimpleDataset(label_list, transform=tokenizer.encode)

    # label dataloader for searching
    sampler = SequentialSampler(label_data)
    label_padding_func = lambda x: padding_util(x, tokenizer.pad_token_id, 64)
    label_dataloader = DataLoader(label_data,
                                  sampler=sampler,
                                  batch_size=16,
                                  collate_fn=label_padding_func)

    # label dataloader for regularization
    reg_sampler = RandomSampler(label_data)
    reg_dataloader = DataLoader(label_data,
                                sampler=reg_sampler,
                                batch_size=4,
                                collate_fn=label_padding_func)

    # Build the training dataset according to the selected mode.
    if args.mode == 'ict':
        train_data = ICTXMCDataset(tokenizer=tokenizer, dataset=args.dataset)
    elif args.mode == 'self-train':
        train_data = PosDataset(tokenizer=tokenizer,
                                dataset=args.dataset,
                                labels=label_list,
                                mode=args.mode)
    elif args.mode == 'finetune-pair':
        # Collect (instance, label) positive pairs from the training split,
        # sample args.ratio of them, and broadcast the sample so every
        # distributed process trains on the same subset.
        train_path = os.path.join(data_path, 'trn.json')
        pos_pair = []
        with open(train_path) as fp:
            for i, line in enumerate(fp):
                inst = json.loads(line.strip())
                inst_id = inst['uid']
                for ind in inst['target_ind']:
                    pos_pair.append((inst_id, ind, i))
        dataset_size = len(pos_pair)
        indices = list(range(dataset_size))
        split = int(np.floor(args.ratio * dataset_size))
        np.random.shuffle(indices)
        train_indices = indices[:split]
        torch.distributed.broadcast_object_list(train_indices,
                                                src=0,
                                                group=None)
        sample_pairs = [pos_pair[i] for i in train_indices]
        train_data = PosDataset(tokenizer=tokenizer,
                                dataset=args.dataset,
                                labels=label_list,
                                mode=args.mode,
                                sample_pairs=sample_pairs)
    elif args.mode == 'finetune-label':
        # Same idea, but sampling whole labels together with all their
        # associated instances.
        label_index = []
        label_path = os.path.join(data_path, 'label_index.json')
        with open(label_path) as fp:
            for line in fp:
                label_index.append(json.loads(line.strip()))
        np.random.shuffle(label_index)
        sample_size = int(np.floor(args.ratio * len(label_index)))
        sample_label = label_index[:sample_size]
        torch.distributed.broadcast_object_list(sample_label,
                                                src=0,
                                                group=None)
        sample_pairs = []
        for i, label in enumerate(sample_label):
            ind = label['ind']
            for inst_id in label['instance']:
                sample_pairs.append((inst_id, ind, i))
        train_data = PosDataset(tokenizer=tokenizer,
                                dataset=args.dataset,
                                labels=label_list,
                                mode=args.mode,
                                sample_pairs=sample_pairs)

    train_sampler = RandomSampler(train_data)
    padding_func = lambda x: ICT_batchify(x, tokenizer.pad_token_id, 64, 288)
    train_dataloader = torch.utils.data.DataLoader(
        train_data,
        sampler=train_sampler,
        batch_size=args.per_device_train_batch_size,
        num_workers=4,
        pin_memory=False,
        collate_fn=padding_func)

    # NOTE(review): if the cached tensor below loads successfully, `test_ids`
    # and `inst_num` are never assigned, yet both are passed to
    # eval_and_cluster() further down — that path would raise NameError.
    # Confirm whether the cache branch is ever exercised as-is.
    try:
        accelerator.print("load cache")
        all_instances = torch.load(
            os.path.join(data_path, 'all_passages_with_titles.json.cache.pt'))
        test_data = SimpleDataset(all_instances.values())
    except:  # NOTE(review): bare except also swallows real I/O/corruption errors
        all_instances = {}
        test_path = os.path.join(data_path, 'tst.json')
        if args.mode == 'ict':
            # In ICT mode, the searchable corpus also includes the training
            # passages actually used by the ICT dataset.
            train_path = os.path.join(data_path, 'trn.json')
            train_instances = {}
            valid_passage_ids = train_data.valid_passage_ids
            with open(train_path) as fp:
                for line in fp:
                    inst = json.loads(line.strip())
                    train_instances[
                        inst['uid']] = inst['title'] + '\t' + inst['content']
            for inst_id in valid_passage_ids:
                all_instances[inst_id] = train_instances[inst_id]
        test_ids = []
        with open(test_path) as fp:
            for line in fp:
                inst = json.loads(line.strip())
                all_instances[
                    inst['uid']] = inst['title'] + '\t' + inst['content']
                test_ids.append(inst['uid'])
        simple_transform = lambda x: tokenizer.encode(
            x, max_length=288, truncation=True)
        test_data = SimpleDataset(list(all_instances.values()),
                                  transform=simple_transform)
        inst_num = len(test_data)

    sampler = SequentialSampler(test_data)
    sent_padding_func = lambda x: padding_util(x, tokenizer.pad_token_id, 288)
    instance_dataloader = DataLoader(test_data,
                                     sampler=sampler,
                                     batch_size=128,
                                     collate_fn=sent_padding_func)

    # prepare pairs
    reader = csv.reader(open(os.path.join(data_path, 'all_pairs.txt'),
                             encoding="utf-8"),
                        delimiter=" ")
    qrels = {}
    for id, row in enumerate(reader):
        query_id, corpus_id, score = row[0], row[1], int(row[2])
        if query_id not in qrels:
            qrels[query_id] = {corpus_id: score}
        else:
            qrels[query_id][corpus_id] = score

    logging.info("| |ICT_dataset|={} pairs.".format(len(train_data)))

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=1e-8)

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, label_dataloader, reg_dataloader, instance_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, label_dataloader, reg_dataloader,
        instance_dataloader)

    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps)
    # args.max_train_steps = 100000
    args.num_train_epochs = math.ceil(args.max_train_steps /
                                      num_update_steps_per_epoch)
    # Warmup fixed at 10% of total optimization steps.
    args.num_warmup_steps = int(0.1 * args.max_train_steps)
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_data)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(
        f"  Instantaneous batch size per device = {args.per_device_train_batch_size}"
    )
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(
        f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Learning Rate = {args.learning_rate}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps),
                        disable=not accelerator.is_local_main_process)
    completed_steps = 0
    # Mixed precision via a manually managed GradScaler (not Accelerate's
    # built-in fp16 handling).
    from torch.cuda.amp import autocast
    scaler = torch.cuda.amp.GradScaler()
    # Initial evaluation; also produces pseudo-label cluster assignments
    # used as contrastive targets during training.
    cluster_result = eval_and_cluster(args, logger, completed_steps,
                                      accelerator.unwrap_model(model),
                                      label_dataloader, label_ids,
                                      instance_dataloader, inst_num, test_ids,
                                      qrels, accelerator)
    reg_iter = iter(reg_dataloader)
    trial_name = f"dim-{args.proj_emb_dim}-bs-{args.per_device_train_batch_size}-{args.dataset}-{args.log}-{args.mode}"
    for epoch in range(args.num_train_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            batch = tuple(t for t in batch)
            label_tokens, inst_tokens, indices = batch
            # ICT mode additionally draws a label batch for the
            # regularization term, cycling the iterator when exhausted.
            if args.mode == 'ict':
                try:
                    reg_data = next(reg_iter)
                except StopIteration:
                    reg_iter = iter(reg_dataloader)
                    reg_data = next(reg_iter)

            # Use clustered pseudo-labels when available, otherwise fall
            # back to raw batch indices as contrastive targets.
            if cluster_result is not None:
                pseudo_labels = cluster_result[indices]
            else:
                pseudo_labels = indices
            with autocast():
                if args.mode == 'ict':
                    label_emb, inst_emb, inst_emb_aug, reg_emb = model(
                        label_tokens, inst_tokens, reg_data)
                    loss, stats_dict = loss_function_reg(
                        label_emb, inst_emb, inst_emb_aug, reg_emb,
                        pseudo_labels, accelerator)
                else:
                    label_emb, inst_emb = model(label_tokens,
                                                inst_tokens,
                                                reg_data=None)
                    loss, stats_dict = loss_function(label_emb, inst_emb,
                                                     pseudo_labels,
                                                     accelerator)
                loss = loss / args.gradient_accumulation_steps

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            # NOTE(review): `step % args.gradient_accumulation_steps == 0`
            # fires at step 0, i.e. after a single micro-batch — verify the
            # boundary was not meant to be `(step + 1) % ... == 0`.
            if step % args.gradient_accumulation_steps == 0 or step == len(
                    train_dataloader) - 1:
                scaler.step(optimizer)
                scaler.update()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

            if completed_steps % args.logging_steps == 0:
                if args.mode == 'ict':
                    logger.info(
                        "| Epoch [{:4d}/{:4d}] Step [{:8d}/{:8d}] Total Loss {:.6e}  Contrast Loss {:.6e}  Reg Loss {:.6e}"
                        .format(
                            epoch,
                            args.num_train_epochs,
                            completed_steps,
                            args.max_train_steps,
                            stats_dict["loss"].item(),
                            stats_dict["contrast_loss"].item(),
                            stats_dict["reg_loss"].item(),
                        ))
                else:
                    logger.info(
                        "| Epoch [{:4d}/{:4d}] Step [{:8d}/{:8d}] Total Loss {:.6e}"
                        .format(
                            epoch,
                            args.num_train_epochs,
                            completed_steps,
                            args.max_train_steps,
                            stats_dict["loss"].item(),
                        ))
            if completed_steps % args.eval_steps == 0:
                # Periodic re-evaluation + re-clustering, then checkpoint
                # both encoder halves.
                cluster_result = eval_and_cluster(
                    args, logger, completed_steps,
                    accelerator.unwrap_model(model), label_dataloader,
                    label_ids, instance_dataloader, inst_num, test_ids, qrels,
                    accelerator)
                unwrapped_model = accelerator.unwrap_model(model)

                unwrapped_model.label_encoder.save(
                    f"{args.output_dir}/{trial_name}/label_encoder")
                unwrapped_model.instance_encoder.save(
                    f"{args.output_dir}/{trial_name}/instance_encoder")

            if completed_steps >= args.max_train_steps:
                break
Exemplo n.º 24
0
def train():
    """Run one training epoch of IGCMAN over the module-level ``train_loader``.

    Uses the module-level ``IGCMAN`` model, ``optimizer`` and
    ``loss_function``. Multi-label ingredient predictions are taken as the
    top-10 classes whose softmax confidence reaches 0.1 and accumulated to
    compute per-class and overall precision / recall / F1.

    :return: tuple ``(c_p, c_r, cf1, o_p, o_r, of1)`` — per-class precision,
        recall, F1 followed by overall precision, recall, F1.
    """
    print('train:')

    IGCMAN.train()  # set the module in training  mode
    train_loss = 0.  # sum of train loss up to current batch
    train_correct = 0
    total = 0

    # Running sums of predicted / correctly-predicted / ground-truth label
    # vectors; the 1e-6 keeps the precision division below finite.
    sum_prediction_label = torch.zeros(1, NUM_INGREDIENT) + 1e-6
    sum_correct_prediction_label = torch.zeros(1, NUM_INGREDIENT)
    sum_ground_truth_label = torch.zeros(1, NUM_INGREDIENT)

    # Loop-invariant lookup table: row 0 is all-zeros (for filtered-out
    # predictions), rows 1..N are one-hot rows. Hoisted out of the batch
    # loop — it was previously rebuilt on every iteration.
    extend_eye_mat = torch.cat(
        (torch.zeros(1, NUM_INGREDIENT), torch.eye(NUM_INGREDIENT)), 0)

    for batch_num, (data, target, category) in enumerate(train_loader):
        if target.sum() == 0:
            continue
        # Keep only samples with at least one ground-truth label. The row
        # mask is computed once, before filtering, so data and target stay
        # aligned (previously the mask for data was recomputed from the
        # already-filtered target, selecting the wrong rows).
        valid_rows = torch.nonzero(target.sum(dim=1)).view(-1)
        target = target.index_select(0, valid_rows)
        data = data.index_select(0, valid_rows)

        if GPU_IN_USE:
            data, target = data.cuda(), target.cuda()
            category = category.cuda()
        data = torch.autograd.Variable(data)
        target = torch.autograd.Variable(target)
        category = torch.autograd.Variable(category)

        # -----forward-----
        optimizer.zero_grad()
        output, M, score_category = IGCMAN(data)
        # ---end forward---

        # ---calculate loss and backward---
        loss = loss_function(output,
                             target,
                             M,
                             score_category,
                             category,
                             add_constraint=True)
        loss.backward()
        optimizer.step()
        # ----------end backward-----------

        # .item() detaches the scalar so the running sum does not keep every
        # batch's autograd graph alive.
        train_loss += loss.item()

        prediction = torch.topk(
            F.softmax(output, dim=1), 10,
            dim=1)  # return the max value and the index tuple
        # Keep only top-10 predictions whose confidence is >= 0.1.
        filter = prediction[0].eq(0.1) + prediction[0].gt(0.1)
        prediction_index = torch.mul(prediction[1] + 1,
                                     filter.type(torch.cuda.LongTensor))
        prediction_label = extend_eye_mat[prediction_index.view(-1)].view(
            -1, 10, NUM_INGREDIENT).sum(dim=1)
        correct_prediction_label = (target.cpu().byte()
                                    & prediction_label.byte()).type(
                                        torch.FloatTensor)

        #count the sum of label vector
        sum_prediction_label += prediction_label.sum(dim=0)
        sum_correct_prediction_label += correct_prediction_label.sum(dim=0)
        sum_ground_truth_label += target.cpu().sum(dim=0)

        # # calculate  accuracy
        _, train_predict = torch.max(score_category, 1)
        total += BATCH_SIZE
        train_correct += torch.sum(train_predict == category.data)
        train_acc = float(train_correct) / total

        if batch_num % LOSS_OUTPUT_INTERVAL == 0:
            # Bug fix: train_acc was passed as a second positional argument
            # to print() instead of being part of the %-format tuple, which
            # raised "not enough arguments for format string" at runtime.
            print('train loss %.3f (batch %d) accuracy %0.3f' %
                  (train_loss / (batch_num + 1), batch_num + 1, train_acc))

    # evaluation metrics
    o_p = torch.div(sum_correct_prediction_label.sum(),
                    sum_prediction_label.sum())
    o_r = torch.div(sum_correct_prediction_label.sum(),
                    sum_ground_truth_label.sum())
    of1 = torch.div(2 * o_p * o_r, o_p + o_r)
    c_p = (torch.div(sum_correct_prediction_label,
                     sum_prediction_label)).sum() / NUM_INGREDIENT
    c_r = (torch.div(sum_correct_prediction_label,
                     sum_ground_truth_label)).sum() / NUM_INGREDIENT
    cf1 = torch.div(2 * c_p * c_r, c_p + c_r)

    return c_p, c_r, cf1, o_p, o_r, of1
Exemplo n.º 25
0
def test():
    """Evaluate IGCMAN over the module-level ``test_loader``.

    Counts top-10 predictions whose softmax confidence reaches 0.1 against
    the multi-label ground truth, and single-label category accuracy from
    ``score_category``.

    :return: tuple ``(c_p, c_r, cf1, o_p, o_r, of1)`` — per-class precision,
        recall, F1 followed by overall precision, recall, F1.
    """
    print('test:')
    IGCMAN.eval()  # set the module in evaluation mode
    test_loss = 0.  # running loss over the batches seen so far
    test_correct = 0
    total = 0

    # Running label-vector sums; the 1e-6 keeps the precision division finite.
    pred_count = torch.zeros(1, NUM_INGREDIENT) + 1e-6
    hit_count = torch.zeros(1, NUM_INGREDIENT)
    truth_count = torch.zeros(1, NUM_INGREDIENT)

    for step, (images, labels, cat) in enumerate(test_loader):
        if labels.sum() == 0:
            continue
        # Drop samples with an empty label vector.
        # NOTE(review): the row mask for `images` is recomputed from the
        # already-filtered `labels`, so it selects the first k rows rather
        # than the originally non-empty ones — kept as-is to preserve the
        # existing behavior; verify upstream.
        labels = labels.index_select(
            0, torch.nonzero(labels.sum(dim=1)).view(-1))
        images = images.index_select(
            0, torch.nonzero(labels.sum(dim=1)).view(-1))

        if GPU_IN_USE:
            images, labels = images.cuda(), labels.cuda()  # set up GPU Tensor
            cat = cat.cuda()

        # f_I = extract_features(data)
        output, M, score_category = IGCMAN(images)
        loss = loss_function(output, labels, M, score_category, cat,
                             add_constraint=True)
        test_loss += loss

        # Top-10 classes per sample together with their softmax confidences.
        confidences, class_idx = torch.topk(F.softmax(output, dim=1), 10,
                                            dim=1)
        keep = confidences.eq(0.1) + confidences.gt(0.1)
        shifted_idx = torch.mul(class_idx + 1,
                                keep.type(torch.cuda.LongTensor))
        # Row 0 of the lookup is all-zeros so filtered entries contribute
        # nothing; rows 1..N are one-hot vectors.
        onehot_lookup = torch.cat(
            (torch.zeros(1, NUM_INGREDIENT), torch.eye(NUM_INGREDIENT)), 0)
        pred_vec = onehot_lookup[shifted_idx.view(-1)].view(
            -1, 10, NUM_INGREDIENT).sum(dim=1)
        hit_vec = (labels.cpu().byte() & pred_vec.byte()).type(
            torch.FloatTensor)

        # Accumulate the label-vector statistics.
        pred_count += pred_vec.sum(dim=0)
        hit_count += hit_vec.sum(dim=0)
        truth_count += labels.cpu().sum(dim=0)

        # Single-label category accuracy.
        _, test_predict = torch.max(score_category, 1)
        total += BATCH_SIZE
        test_correct += torch.sum(test_predict == cat.data)
        test_acc = float(test_correct) / total

        if step % LOSS_OUTPUT_INTERVAL == 0:
            print('Test loss %.3f (batch %d) accuracy %0.3f' %
                  (test_loss / (step + 1), step + 1, test_acc))

    # Overall (micro) and per-class (macro) precision / recall / F1.
    o_p = torch.div(hit_count.sum(), pred_count.sum())
    o_r = torch.div(hit_count.sum(), truth_count.sum())
    of1 = torch.div(2 * o_p * o_r, o_p + o_r)
    c_p = (torch.div(hit_count, pred_count)).sum() / NUM_INGREDIENT
    c_r = (torch.div(hit_count, truth_count)).sum() / NUM_INGREDIENT
    cf1 = torch.div(2 * c_p * c_r, c_p + c_r)

    return c_p, c_r, cf1, o_p, o_r, of1
Exemplo n.º 26
0
# Best validation loss seen so far; eval() below updates it and checkpoints
# the model when it improves.
best_loss = 1e9

# NOTE(review): `count` is read and incremented below but never initialized
# in this snippet — presumably `count = 0` appears earlier in the original
# script; confirm before running standalone.
for epoch in range(args.epochs):
    model.train()  # Enable dropout (if have).

    start_time = time.time()
    # Re-draw negative samples for this epoch (implicit-feedback setup).
    train_loader.dataset.ng_sample()

    for user, item, label, noisy_or_not in train_loader:
        user = user.cuda()
        item = item.cuda()
        label = label.float().cuda()

        model.zero_grad()
        prediction = model(user, item)
        # drop_rate_schedule(count) sets the per-step sample-drop rate
        # (noisy-label training).
        loss = loss_function(prediction, label, drop_rate_schedule(count))
        loss.backward()
        optimizer.step()

        # Periodic logging + validation; eval() (project helper shadowing
        # the builtin) switches the model to eval mode, so train mode is
        # restored right after.
        if count % args.eval_freq == 0 and count != 0:
            print("epoch: {}, iter: {}, loss:{}".format(epoch, count, loss))
            best_loss = eval(model, valid_loader, best_loss, count)
            model.train()

        count += 1

print(
    "############################## Training End. ##############################"
)
test_model = torch.load('{}{}_{}-{}.pth'.format(model_path, args.model,
                                                args.drop_rate,
Exemplo n.º 27
0
        optim_d.zero_grad()

        source = data['source'].float().cuda()
        target = data['target'].float().cuda()
        vec = data['vec'].float().cuda()
        
        image_flow, flow, _  = af_plus(source, vec)
        
        # apply the scale space
        source_pad = F.pad(source, (1, 1, 1, 1))  
        gauss_weight = gauss_update_method(sigma)
        source_blurred = gauss_conv(source_pad, weight=gauss_weight, groups=3)

        image_flow_blurred = apply_warp(source_blurred, flow, grid)
        
        loss_net1, discrim_loss = loss_function(image_flow, target, discrim, gan_loss, feature_loss,
                                             l1_loss, need_gan_loss=True, image_out_blurred=image_flow_blurred)
        
        # discriminator optimization
        discrim_loss.mean().backward()
        optim_d.step()

        blur_regularization = 1e-3*blur_regularization.cuda()
        
        net_loss = loss_net1 + blur_regularization
        net_loss.mean().backward()
        
        # main network optimization
        optim.step()
        
        loss_train += net_loss.mean().item()
        loss_train_d += discrim_loss.mean().item()
Exemplo n.º 28
0
def train(args, images, fids, pids, max_fid_len, log):
    '''
    Creation model and training neural network

    :param args: all stored arguments
    :param images: prepared images for training
    :param fids: figure id (relative paths from image_root to images)
    :param pids: person id (or car id) for all images
    :param log: log file, where logs from training are stored
    :return: saved files (checkpoints, train log file)
    '''
    ###################################################################################################################
    # CREATE MODEL
    ###################################################################################################################
    # Create the model and an embedding head.

    model = import_module('nets.resnet_v1_50')
    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    drops = {}
    if args.dropout is not None:
        drops = getDropoutProbs(args.dropout)
    b4_layers = None
    try:
        b4_layers = int(args.b4_layers)
        if b4_layers not in [1, 2, 3]: raise ValueError()
    except:
        ValueError("Argument exception: b4_layers has to be in [1, 2, 3]")

    endpoints, body_prefix = model.endpoints(images,
                                             b4_layers,
                                             drops,
                                             is_training=True,
                                             resnet_stride=int(
                                                 args.resnet_stride))
    endpoints['emb'] = endpoints['emb_raw'] = slim.fully_connected(
        endpoints['model_output'],
        args.embedding_dim,
        activation_fn=None,
        weights_initializer=tf.orthogonal_initializer(),
        scope='emb')

    step_pl = tf.placeholder(dtype=tf.float32)

    features = endpoints['emb']

    # Create the loss in two steps:
    # 1. Compute all pairwise distances according to the specified metric.
    # 2. For each anchor along the first dimension, compute its loss.
    dists = loss.cdist(features, features, metric=args.metric)
    losses, train_top1, prec_at_k, _, probe_neg_dists, pos_dists, neg_dists = loss.loss_function(
        dists,
        pids, [args.alpha1, args.alpha2, args.alpha3],
        batch_precision_at_k=args.batch_k - 1)

    # Count the number of active entries, and compute the total batch loss.
    num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 1e-5), tf.float32))
    loss_mean = tf.reduce_mean(losses)

    # Some logging for tensorboard.
    tf.summary.histogram('loss_distribution', losses)
    tf.summary.scalar('loss', loss_mean)
    tf.summary.scalar('batch_top1', train_top1)
    tf.summary.scalar('batch_prec_at_{}'.format(args.batch_k - 1), prec_at_k)
    tf.summary.scalar('active_count', num_active)
    tf.summary.scalar('embedding_pos_dists', tf.reduce_mean(pos_dists))
    tf.summary.scalar('embedding_probe_neg_dists',
                      tf.reduce_mean(probe_neg_dists))
    tf.summary.scalar('embedding_neg_dists', tf.reduce_mean(neg_dists))
    tf.summary.histogram('embedding_dists', dists)
    tf.summary.histogram('embedding_pos_dists', pos_dists)
    tf.summary.histogram('embedding_probe_neg_dists', probe_neg_dists)
    tf.summary.histogram('embedding_neg_dists', neg_dists)
    tf.summary.histogram('embedding_lengths',
                         tf.norm(endpoints['emb_raw'], axis=1))

    # Create the mem-mapped arrays in which we'll log all training detail in
    # addition to tensorboard, because tensorboard is annoying for detailed
    # inspection and actually discards data in histogram summaries.
    batch_size = args.batch_p * args.batch_k
    if args.detailed_logs:
        log_embs = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'embeddings'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size, args.embedding_dim))
        log_loss = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'losses'),
            dtype=np.float32,
            shape=(args.train_iterations, batch_size))
        log_fids = lb.create_or_resize_dat(
            os.path.join(args.experiment_root, 'fids'),
            dtype='S' + str(max_fid_len),
            shape=(args.train_iterations, batch_size))

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    # Define the optimizer and the learning-rate schedule.
    # Unfortunately, we get NaNs if we don't handle no-decay separately.
    global_step = tf.Variable(0, name='global_step', trainable=False)

    if args.sgdr:
        learning_rate = tf.train.cosine_decay_restarts(
            learning_rate=args.learning_rate,
            global_step=global_step,
            first_decay_steps=4000,
            t_mul=1.5)
    else:
        if 0 <= args.decay_start_iteration < args.train_iterations:
            learning_rate = tf.train.exponential_decay(
                args.learning_rate,
                tf.maximum(0, global_step - args.decay_start_iteration),
                args.train_iterations - args.decay_start_iteration,
                float(args.lr_decay))
        else:
            learning_rate = args.learning_rate
    tf.summary.scalar('learning_rate', learning_rate)
    optimizer = tf.train.AdamOptimizer(tf.convert_to_tensor(learning_rate))

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_mean, global_step=global_step)

    # Define a saver for the complete model.
    checkpoint_saver = tf.train.Saver(max_to_keep=0)
    # Run the whole training procedure inside one TF1 graph-mode session:
    # initialize or restore weights, then iterate the train op until the
    # requested iteration count (or a SIGINT/SIGTERM interrupt), writing
    # summaries and periodic checkpoints along the way.
    with tf.Session() as sess:
        if args.resume:
            # In case we're resuming, simply load the full checkpoint to init.
            last_checkpoint = tf.train.latest_checkpoint(args.experiment_root)
            log.info('Restoring from checkpoint: {}'.format(last_checkpoint))
            checkpoint_saver.restore(sess, last_checkpoint)
        else:
            # But if we're starting from scratch, we may need to load some
            # variables from the pre-trained weights, and random init others.
            sess.run(tf.global_variables_initializer())
            if args.initial_checkpoint is not None:
                # A separate Saver restricted to `model_variables` so that
                # only the pre-trained subset is restored; everything else
                # keeps its random initialization from above.
                saver = tf.train.Saver(model_variables)
                saver.restore(sess, args.initial_checkpoint)

            # In any case, we also store this initialization as a checkpoint,
            # such that we could run exactly reproducible experiments.
            checkpoint_saver.save(sess,
                                  os.path.join(args.experiment_root,
                                               'checkpoint'),
                                  global_step=0)

        # Collect all summary ops defined on the graph and open a writer
        # in the experiment directory (also dumps the graph for TensorBoard).
        merged_summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.experiment_root,
                                               sess.graph)

        # `global_step` is a graph variable; after a resume it already holds
        # the iteration we stopped at, so training continues seamlessly.
        start_step = sess.run(global_step)
        step = start_step
        log.info('Starting training from iteration {}.'.format(start_step))

        ###################################################################################################################
        # TRAINING
        ###################################################################################################################
        # Finally, here comes the main-loop. This `Uninterrupt` is a handy
        # utility such that an iteration still finishes on Ctrl+C and we can
        # stop the training cleanly.
        with lb.Uninterrupt(sigs=[SIGINT, SIGTERM], verbose=True) as u:
            for i in range(start_step, args.train_iterations):
                # Compute gradients, update weights, store logs!
                # NOTE(review): the current `step` is fed back in via
                # `step_pl` — presumably used by a schedule in the graph;
                # `step` is then refreshed from the graph's global_step.
                start_time = time.time()
                _, summary, step, b_prec_at_k, b_embs, b_loss, b_fids = \
                    sess.run([train_op, merged_summary, global_step,
                              prec_at_k, features, losses, fids], feed_dict={step_pl: step})
                elapsed_time = time.time() - start_time

                # Compute the iteration speed and add it to the summary.
                # We did observe some weird spikes that we couldn't track down.
                # (Hand-built Summary proto, since timing happens outside the graph.)
                summary2 = tf.Summary()
                summary2.value.add(tag='secs_per_iter',
                                   simple_value=elapsed_time)
                summary_writer.add_summary(summary2, step)
                summary_writer.add_summary(summary, step)

                if args.detailed_logs:
                    # Per-iteration embeddings/losses/ids, indexed by the loop
                    # counter `i` into pre-allocated logs (allocated elsewhere).
                    log_embs[i], log_loss[i], log_fids[
                        i] = b_embs, b_loss, b_fids

                # Do a huge print out of the current progress. Maybe steal from here.
                # ETA is a naive extrapolation from this iteration's wall time.
                seconds_todo = (args.train_iterations - step) * elapsed_time
                log.info(
                    'iter:{:6d}, loss min|avg|max: {:.3f}|{:.3f}|{:6.3f}, '
                    'batch-p@{}: {:.2%}, ETA: {} ({:.2f}s/it), lr={:.4g}'.
                    format(step, float(np.min(b_loss)), float(np.mean(b_loss)),
                           float(np.max(b_loss)), args.batch_k - 1,
                           float(b_prec_at_k),
                           timedelta(seconds=int(seconds_todo)), elapsed_time,
                           # HACK: reads the optimizer's private `_lr` tensor
                           # to log the (possibly scheduled) learning rate.
                           sess.run(optimizer._lr)))
                sys.stdout.flush()
                sys.stderr.flush()

                # Save a checkpoint of training every so often.
                # (frequency <= 0 disables intermediate checkpoints entirely)
                if (args.checkpoint_frequency > 0
                        and step % args.checkpoint_frequency == 0):
                    checkpoint_saver.save(sess,
                                          os.path.join(args.experiment_root,
                                                       'checkpoint'),
                                          global_step=step)

                # Stop the main-loop at the end of the step, if requested.
                if u.interrupted:
                    log.info("Interrupted on request!")
                    break

        # Store one final checkpoint. This might be redundant, but it is crucial
        # in case intermediate storing was disabled and it saves a checkpoint
        # when the process was interrupted.
        checkpoint_saver.save(sess,
                              os.path.join(args.experiment_root, 'checkpoint'),
                              global_step=step)