Example #1
0
def prediction(img1, img2, label, weight):
    """Run change-detection inference with SiamUNetU and save binary masks.

    Loads a checkpoint into a DataParallel-wrapped SiamUNetU, runs it over the
    validation split and writes thresholded masks into ./change/.

    Args:
        img1, img2, label: unused here; kept so existing callers keep working.
        weight: path to a checkpoint file containing a 'state_dict' entry.
    """
    # BUG FIX: the original printed the literal string "weight"; log the
    # actual checkpoint path instead.
    print(weight)

    test_transform_det = trans.Compose([
        trans.Scale((960, 960)),
    ])
    model = SiamUNetU(in_ch=3)
    # Wrap in DataParallel before loading: the checkpoint keys carry the
    # 'module.' prefix (see the commented-out strip variant in history).
    model = torch.nn.DataParallel(model)
    if torch.cuda.is_available():
        model.cuda()
    checkpoint = torch.load(weight)
    model.load_state_dict(checkpoint['state_dict'])

    test_data = my_dataset.Dataset(cfg.VAL_DATA_PATH,
                                   cfg.VAL_LABEL_PATH,
                                   cfg.VAL_TXT_PATH,
                                   'val',
                                   transform=True,
                                   transform_med=test_transform_det)
    test_dataloader = DataLoader(test_data,
                                 batch_size=cfg.TEST_BATCH_SIZE,
                                 shuffle=False,
                                 num_workers=8,
                                 pin_memory=True)
    crop = 0  # truthy -> tiled (rows x cols) inference; falsy -> whole image

    rows = 12
    cols = 12
    model.eval()
    for batch_idx, val_batch in enumerate(test_dataloader):
        batch_x1, batch_x2, _, filename, h, w = val_batch
        # Derive the output mask name from the input image name.
        filename = filename[0].split('/')[-1].replace('image',
                                                      'mask_2017').replace(
                                                          '.png', '.tif')
        if not os.path.exists('./change'):
            os.mkdir('./change')
        if crop:
            outputs = np.zeros((cfg.TEST_BATCH_SIZE, 1, 960, 960))
            # BUG FIX: 'i' was initialised once before the dataloader loop, so
            # every batch after the first skipped the tiling loops entirely.
            i = 0
            while (i + w // rows <= w):
                j = 0
                while (j + h // cols <= h):
                    # NOTE(review): indexing the batch dimension with the
                    # dataloader index only works when TEST_BATCH_SIZE == 1
                    # and batch_idx == 0 -- confirm against the config.
                    batch_x1_ij = batch_x1[batch_idx, :, i:i + w // rows,
                                           j:j + h // cols]
                    batch_x2_ij = batch_x2[batch_idx, :, i:i + w // rows,
                                           j:j + h // cols]
                    batch_x1_ij = np.expand_dims(batch_x1_ij, axis=0)
                    batch_x2_ij = np.expand_dims(batch_x2_ij, axis=0)
                    batch_x1_ij, batch_x2_ij = Variable(
                        torch.from_numpy(batch_x1_ij)).cuda(), Variable(
                            torch.from_numpy(batch_x2_ij)).cuda()
                    with torch.no_grad():
                        output = model(batch_x1_ij, batch_x2_ij)
                    output_w, output_h = output.shape[-2:]
                    output = torch.sigmoid(output).view(output_w, output_h, -1)
                    output = output.data.cpu().numpy()
                    # Binarise at the configured probability threshold.
                    output = np.where(output > cfg.THRESH, 255, 0)
                    outputs[batch_idx, :, i:i + w // rows,
                            j:j + h // cols] = output

                    j += h // cols
                i += w // rows

            print(batch_idx)
            print('./change/{}'.format(filename))
            cv2.imwrite('./change/crop_{}'.format(filename), outputs[batch_idx,
                                                                     0, :, :])
        else:
            batch_x1, batch_x2 = Variable(batch_x1).cuda(), Variable(
                batch_x2).cuda()
            with torch.no_grad():
                output = model(batch_x1, batch_x2)
            output_w, output_h = output.shape[-2:]
            output = torch.sigmoid(output).view(output_w, output_h, -1)
            output = output.data.cpu().numpy()
            output = np.where(output > cfg.THRESH, 255, 0)
            print('./change/{}'.format(filename))
            cv2.imwrite('./change/{}'.format(filename), output)
Example #2
0
def prediction(weight):
    """Run SiamUNet change-detection inference and write GeoTIFF masks.

    Loads a checkpoint, iterates the test split, thresholds the sigmoid
    output and writes each single-band mask as a GeoTIFF via GDAL.

    Args:
        weight: path to a checkpoint file containing a 'state_dict' entry.
    """
    # BUG FIX: the original printed the literal string "weight"; log the
    # actual checkpoint path instead.
    print(weight)

    test_transform_det = trans.Compose([
        trans.Scale(cfg.TEST_TRANSFROM_SCALES),
    ])
    model = SiamUNet()
    if torch.cuda.is_available():
        model.cuda()
        print('gpu')

    checkpoint = torch.load(weight)
    model.load_state_dict(checkpoint['state_dict'])
    test_data = my_dataset.Dataset(cfg.TEST_DATA_PATH,
                                   cfg.TEST_LABEL_PATH,
                                   cfg.TEST_TXT_PATH,
                                   'val',
                                   transform=True,
                                   transform_med=test_transform_det)
    test_dataloader = DataLoader(test_data,
                                 batch_size=cfg.TEST_BATCH_SIZE,
                                 shuffle=False,
                                 num_workers=8,
                                 pin_memory=True)
    model.eval()
    for batch_idx, val_batch in enumerate(test_dataloader):
        batch_x1, batch_x2, mask, im_name, h, w = val_batch
        print('mask_type{}'.format(mask.type))

        with torch.no_grad():
            batch_x1, batch_x2 = Variable((batch_x1)).cuda(), Variable(
                ((batch_x2))).cuda()

            try:
                print('try')
                output = model(batch_x1, batch_x2)
                del batch_x1, batch_x2
            except RuntimeError as exception:
                if 'out of memory' in str(exception):
                    print('WARNING: out of memory')
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                    # BUG FIX: 'output' is undefined after an OOM failure, so
                    # the post-processing below raised NameError; skip batch.
                    continue
                else:
                    print('exception')
                    raise
        output_w, output_h = output.shape[-2:]
        output = torch.sigmoid(output).view(output_w, output_h, -1)
        output = output.data.cpu().numpy()
        # Binarise at the configured probability threshold.
        output = np.where(output > cfg.THRESH, 255, 0)

        print(im_name)
        # Build the output path from the tile id embedded in the file name.
        im_n = im_name[0].split('/')[1].split('.')[0].split('_')
        im__path = 'final_result/weight50_dmc/mask_2017_2018_960_960_' + im_n[
            4] + '.tif'

        im_data = np.squeeze(output)
        print(im_data.shape)
        im_data = np.array([im_data])  # add a leading band axis
        print(im_data.shape)
        # Identity geotransform / empty projection: output carries no real CRS.
        im_geotrans = (0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
        im_proj = ''
        im_width = 960
        im_height = 960
        im_bands = 1
        datatype = gdal.GDT_Byte
        driver = gdal.GetDriverByName("GTiff")
        dataset = driver.Create(im__path, im_width, im_height, im_bands,
                                datatype)
        # BUG FIX: use 'is not None' and only write bands when Create()
        # succeeded; the original wrote into a possibly-None dataset.
        if dataset is not None:
            print("----{}".format(im__path))
            dataset.SetGeoTransform(im_geotrans)
            dataset.SetProjection(im_proj)
            for band in range(im_bands):
                dataset.GetRasterBand(band + 1).WriteArray(im_data[band])

        del dataset  # flush and close the GDAL dataset
Example #3
0
                datetime.datetime.now().strftime("%Y-%m-%d-%H%M%S"),
                ",".join(
                    ("{}={}".format(
                        re.sub("(.)[^_]*_?", r"\1", key), value) \
                            for key, value in sorted(vars(args).items()) \
                                if not '/' in str(value) \
                                and not 'threads' in key
                                and not 'logdir' in key
                    )
                )
            )

    print("The logdir is: {}".format(args.logdir))

    # Load the data
    train_set = data.Dataset(args.train_set, args.vocab, shuffle_batches=True)
    valid_set = data.Dataset(args.valid_set, args.vocab, shuffle_batches=False)

    # Construct the network
    network = Network(threads=args.threads)
    network.construct(args, train_set.num_tokens)

    # Train, batches
    print("Training started.")
    for i in range(args.epochs):
        while not train_set.epoch_finished():
            batch = train_set.next_batch(args.batch_size)
            network.train_batch(batch)

            # Saving embeddings
            #embeddings = network.embeddings()
Example #4
0
def train(params):
    """Multi-GPU GAN-style training loop for a Transformer (TF1 graph mode).

    Builds one model tower per GPU, averages gradients across towers, then
    runs a session loop that trains, logs, periodically evaluates BLEU on the
    validation set, and checkpoints on BLEU improvement.

    Args:
        params: hyperparameter object (learning rate, Adam betas, hidden
            size, etc.); GPU count / paths come from the module-level
            flags_obj.
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Non-trainable step counter shared by both apply_gradients ops.
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        # calculate the learning rate schedule
        learning_rate = get_learning_rate(params.learning_rate,
                                          params.hidden_size,
                                          params.learning_rate_warmup_steps,
                                          global_step)

        optimizer = tf.contrib.opt.LazyAdamOptimizer(
            learning_rate,
            beta1=params.optimizer_adam_beta1,
            beta2=params.optimizer_adam_beta2,
            epsilon=params.optimizer_adam_epsilon)

        # get src,tgt sentence for each model tower
        my_dataset = dataset.Dataset(params)
        # src, tgt = my_dataset.train_input_fn(params)
        # batch_queue = tf.contrib.slim.prefetch_queue.prefetch_queue(
        #     [src, tgt], capacity=2 * flags_obj.num_gpus
        # )
        train_iterator = my_dataset.train_input_fn(params)
        valid_iterator = my_dataset.eval_input_fn(params)

        # Per-tower gradient lists for the discriminator-style loss (loss)
        # and the generator loss (g_loss).
        tower_grads = []
        g_tower_grads = []
        model = transformer_5.Transformer(params, is_train=True)
        # AUTO_REUSE lets every tower share one set of variables.
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            #tf.logging.info(tf.get_variable_scope())
            for i in xrange(flags_obj.num_gpus):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('%s_%d' % (TOWER_NAME, i)) as scope:
                        tf.logging.info("Build graph on gpu:{}".format(i))
                        loss, g_loss, rewards_mb = gan_tower_loss(
                            scope, model, train_iterator)
                        # Reuse variables for the next tower.
                        # tf.get_variable_scope().reuse_variables()
                        # Retain the summaries from the final tower.
                        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                      scope)

                        grads = optimizer.compute_gradients(loss)
                        g_grads = optimizer.compute_gradients(g_loss)
                        #for var, grad in grads:
                        #    tf.logging.info(var)
                        tf.logging.info(
                            "total trainable variables number: {}".format(
                                len(grads)))
                        tower_grads.append(grads)
                        g_tower_grads.append(g_grads)

                    # Evaluation graph lives on the first GPU only.
                    if i == 0 and valid_iterator:
                        #with tf.name_scope('%s_%d' % (TOWER_NAME, i)) as scope:
                        # valid_loss_op = tower_loss(scope, valid_iterator)
                        #val_pred, val_target = evaluation(valid_iterator)
                        val_loss_op, val_logits_op, val_tgt_op = evaluation(
                            model, valid_iterator)
                        summaries.append(
                            tf.summary.scalar("val_loss", val_loss_op))

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers.
        if len(tower_grads) > 1:
            grads = average_gradients(tower_grads)
            g_grads = average_gradients(g_tower_grads)
        else:
            grads = tower_grads[0]
            g_grads = g_tower_grads[0]

        # Add a summary to track the learning rate.
        summaries.append(tf.summary.scalar('learning_rate', learning_rate))

        # Add histograms for gradients.
        for grad, var in grads:
            if grad is not None:
                summaries.append(
                    tf.summary.histogram(var.op.name + '/gradients', grad))

        # Apply the gradients to adjust the shared variables.
        # NOTE(review): both ops increment global_step, so each grouped
        # train_op run advances the step counter twice -- confirm intended.
        apply_gradient_op = optimizer.apply_gradients(grads,
                                                      global_step=global_step)
        g_apply_gradient_op = optimizer.apply_gradients(
            g_grads, global_step=global_step)

        # Add histograms for trainable variables.
        for var in tf.trainable_variables():
            summaries.append(tf.summary.histogram(var.op.name, var))

        # Track the moving averages of all trainable variables.
        #variable_averages = tf.train.ExponentialMovingAverage(
        #    MOVING_AVERAGE_DECAY, global_step)
        #variables_averages_op = variable_averages.apply(tf.trainable_variables())

        # Group all updates to into a single train op.
        # train_op = tf.group(apply_gradient_op, variables_averages_op)
        train_op = tf.group(apply_gradient_op, g_apply_gradient_op)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=20)

        # Build the summary operation from the last tower summaries.
        summary_op = tf.summary.merge(summaries)

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Start running operations on the Graph. allow_soft_placement must be set to
        # True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        sess_config.allow_soft_placement = True

        with tf.Session(config=sess_config) as sess:
            sess.run(init)
            sess.run(tf.local_variables_initializer())

            sess.run(train_iterator.initializer)

            # Resume from the latest checkpoint in model_dir if one exists.
            #ckpt = tf.train.latest_checkpoint(flags_obj.pretrain_dir)
            ckpt = tf.train.latest_checkpoint(flags_obj.model_dir)
            tf.logging.info("ckpt {}".format(ckpt))
            if ckpt and tf.train.checkpoint_exists(ckpt):
                tf.logging.info(
                    "Reloading model parameters..from {}".format(ckpt))
                saver.restore(sess, ckpt)
            else:
                tf.logging.info("Create a new model...{}".format(
                    flags_obj.pretrain_dir))

            # Start the queue runners.
            tf.train.start_queue_runners(sess=sess)
            summary_writer = tf.summary.FileWriter(flags_obj.model_dir,
                                                   sess.graph)

            best_bleu = 0.0
            for step in xrange(flags_obj.train_steps):
                start_time = time.time()
                _, loss_value, g_loss_value, rewards_mb_value, baseline_value, total_rewards_value = sess.run(
                    [
                        train_op, loss, g_loss, rewards_mb, model.baseline,
                        model.total_rewards
                    ])
                tf.logging.info(
                    "step = {}, step_g_loss = {:.4f}, step_loss = {:.4f}".
                    format(step, g_loss_value, loss_value))
                duration = time.time() - start_time

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                if step % 100 == 0:
                    # Throughput bookkeeping; computed but only logged below.
                    num_examples_per_step = flags_obj.batch_size * flags_obj.num_gpus
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = duration / flags_obj.num_gpus

                    tf.logging.info(
                        "step = {}, step_g_loss = {:.4f}, step_loss = {:.4f}, reward_mb = {}, baseline = {}, total_rewards = {}"
                        .format(step, g_loss_value, loss_value,
                                rewards_mb_value[:5], baseline_value[:5],
                                total_rewards_value[:5]))

                if step % 100 == 0:
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, step)

                # Periodic full pass over the validation set to compute BLEU.
                if step % flags_obj.steps_between_evals == 0:
                    sess.run(valid_iterator.initializer)
                    tf.logging.info(
                        "-------------------- Validation step ...{} -------------------------- ----------"
                        .format(step))
                    total_bleu = 0.0
                    total_size = 0
                    total_loss = 0.0
                    while True:
                        try:
                            val_loss, val_logit, val_tgt = sess.run(
                                [val_loss_op, val_logits_op, val_tgt_op])
                            val_pred = np.argmax(val_logit, axis=-1)
                            val_bleu = metrics.compute_bleu(val_tgt, val_pred)
                            batch_size = val_pred.shape[0]
                            # Weight per-batch metrics by batch size.
                            total_bleu += val_bleu * batch_size
                            total_loss += val_loss * batch_size
                            total_size += batch_size
                            tf.logging.info(
                                "pairs shape {}, {}, step_bleu: {:.5f}, step_loss: {:.4f}"
                                .format(val_pred.shape, val_tgt.shape,
                                        val_bleu, val_loss))
                        except tf.errors.OutOfRangeError:
                            # Iterator exhausted: log one sample pair and stop.
                            pred_string = array_to_string(val_pred[-1])
                            tgt_string = array_to_string(val_tgt[-1])
                            tf.logging.info(
                                "prediction:\n{}".format(pred_string))
                            tf.logging.info("target:\n{}".format(tgt_string))
                            tf.logging.info(
                                "Finished going through the valid dataset")
                            break
                    total_bleu /= total_size
                    total_loss /= total_size
                    tf.logging.info(
                        "{}, Step: {}, Valid loss: {:.6f}, Valid bleu : {:.6f}"
                        .format(datetime.now(), step, total_loss, total_bleu))
                    tf.logging.info(
                        "--------------------- Finish evaluation -----------------------------------------------------"
                    )
                    # Save the model checkpoint periodically.
                    # Step 0's BLEU is discarded so the untrained model is
                    # never checkpointed as "best".
                    if step == 0:
                        total_bleu = 0.0

                    if total_bleu > best_bleu:
                        best_bleu = total_bleu
                        checkpoint_path = os.path.join(flags_obj.model_dir,
                                                       'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=step)
                        tf.logging.info(
                            "Saving model at {}".format(checkpoint_path + "-" +
                                                        str(step)))
Example #5
0
def build_graph(params):
    """Assemble the multi-GPU GAN training graph (TF1 graph mode).

    Creates dataset iterators, a generator and a discriminator, builds one
    tower per GPU computing both the cross-entropy (pretrain) loss and the
    policy-gradient generator loss, averages tower gradients, and returns
    the ops needed by the caller's session loop.

    Args:
        params: hyperparameter object (learning rate, Adam betas, vocab
            sizes, label smoothing, ...); GPU count / checkpoint dir come
            from the module-level flags_obj.

    Returns:
        Tuple of (g_model, d_model, train_return, valid_return,
        dataset_iter) where train_return bundles the train op and loss/LR
        tensors, valid_return the validation prediction/target/source
        tensors, and dataset_iter the (train, valid) iterators.
    """
    my_dataset = dataset.Dataset(params)
    train_iterator = my_dataset.train_input_fn(params)
    valid_iterator = my_dataset.eval_input_fn(params)

    # Seed global_step from the checkpoint filename suffix when resuming so
    # the LR schedule continues where it left off.
    ckpt = tf.train.latest_checkpoint(flags_obj.model_dir)
    if ckpt and tf.train.checkpoint_exists(ckpt):
        init_step = int(
            tf.train.latest_checkpoint(flags_obj.model_dir).split("-")[-1])
        global_step = tf.get_variable('global_step',
                                      initializer=init_step,
                                      trainable=False)
    else:
        init_step = 0
        global_step = tf.Variable(init_step,
                                  trainable=False,
                                  name="global_step")

    learning_rate = get_learning_rate(params.learning_rate, params.hidden_size,
                                      params.learning_rate_warmup_steps,
                                      global_step)

    optimizer = tf.contrib.opt.LazyAdamOptimizer(
        learning_rate,
        beta1=params.optimizer_adam_beta1,
        beta2=params.optimizer_adam_beta2,
        epsilon=params.optimizer_adam_epsilon)

    # Per-tower gradients: tower_grads for the cross-entropy loss,
    # g_tower_grads for the reward-weighted generator loss.
    tower_grads = []
    g_tower_grads = []
    g_model = gen_and_dis.Generator(params,
                                    is_train=True,
                                    name_scope="Transformer")
    d_model = gen_and_dis.Discriminator(params,
                                        is_train=True,
                                        name_scope="Discriminator")
    # AUTO_REUSE lets every tower share one set of variables.
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        for i in xrange(flags_obj.num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%s_%d' % (TOWER_NAME, i)) as scope:
                    tf.logging.info("Build graph on gpu:{}".format(i))
                    # pretrain loss
                    logits = g_model.inference(train_iterator.source,
                                               train_iterator.target)
                    xentropy, weights = metrics.padded_cross_entropy_loss(
                        logits, train_iterator.target, params.label_smoothing,
                        params.target_vocab_size)
                    xen_loss = tf.reduce_sum(xentropy) / tf.reduce_sum(weights)

                    # g_loss
                    # Free-running generation (targets=None), then pad/trim to
                    # a fixed shape before feeding the discriminator.
                    gen_samples = g_model.inference(train_iterator.source,
                                                    None)["outputs"]
                    deal_samples = train_helper._trim_and_pad(gen_samples)
                    given_num, rewards, roll_mean_loss, real_mean_loss = g_model.get_reward(
                        real_inputs=train_iterator.source,
                        real_targets=train_iterator.target,
                        gen_targets=deal_samples,
                        roll_num=flags_obj.roll_num,
                        discriminator=d_model)
                    g_loss = g_model.g_loss(gen_targets=deal_samples,
                                            given_num=given_num,
                                            rewards=rewards)

                    xen_grads = optimizer.compute_gradients(xen_loss)
                    gen_grads = optimizer.compute_gradients(g_loss)

                    # Keep only generator ("Transformer") variables; the
                    # discriminator is not updated by these ops.
                    g_grads = []
                    x_grads = []
                    for grad, var in gen_grads:
                        if "Transformer" in var.name:
                            g_grads.append((grad, var))
                    for grad, var in xen_grads:
                        if "Transformer" in var.name:
                            x_grads.append((grad, var))

                    tf.logging.info(
                        "total trainable variables number: {}, {}".format(
                            len(g_grads), len(x_grads)))
                    tower_grads.append(x_grads)
                    g_tower_grads.append(g_grads)

                # Validation prediction graph lives on the first GPU only.
                if i == 0 and valid_iterator:
                    val_pred = g_model.inference(inputs=valid_iterator.source,
                                                 targets=None)["outputs"]

    # Average gradients across towers (the multi-GPU sync point).
    if len(tower_grads) > 1:
        print(len(tower_grads[0]), len(tower_grads[1]))
        x_grads = train_helper.average_gradients(tower_grads)
        g_grads = train_helper.average_gradients(g_tower_grads)
    else:
        x_grads = tower_grads[0]
        g_grads = g_tower_grads[0]

    # NOTE(review): both apply ops increment global_step, so one grouped
    # train_op run advances the counter twice -- confirm intended.
    apply_gradient_op = optimizer.apply_gradients(x_grads,
                                                  global_step=global_step)
    g_apply_gradient_op = optimizer.apply_gradients(g_grads,
                                                    global_step=global_step)

    train_op = tf.group(apply_gradient_op, g_apply_gradient_op)

    train_return = (train_op, global_step, g_loss, xen_loss, rewards,
                    learning_rate, init_step, roll_mean_loss, real_mean_loss)
    valid_return = (val_pred, valid_iterator.target, valid_iterator.source)
    dataset_iter = (train_iterator, valid_iterator)
    return g_model, d_model, train_return, valid_return, dataset_iter
Example #6
0
def main():
    """Train SiamUNet on the change-detection dataset and save checkpoints.

    Builds train/val dataloaders, optionally resumes from a saved
    checkpoint, trains with Adam + StepLR, runs a quick validation pass
    every 5 training batches, and checkpoints every 5 epochs plus once at
    the end.
    """
    train_transform_det = trans.Compose([
        trans.Scale(cfg.TRANSFROM_SCALES),
    ])
    val_transform_det = trans.Compose([
        trans.Scale(cfg.TRANSFROM_SCALES),
    ])

    train_data = my_dataset.Dataset(cfg.TRAIN_DATA_PATH,
                                    cfg.TRAIN_LABEL_PATH,
                                    cfg.TRAIN_TXT_PATH,
                                    'train',
                                    transform=True,
                                    transform_med=train_transform_det)
    val_data = my_dataset.Dataset(cfg.VAL_DATA_PATH,
                                  cfg.VAL_LABEL_PATH,
                                  cfg.VAL_TXT_PATH,
                                  'val',
                                  transform=True,
                                  transform_med=val_transform_det)
    train_dataloader = DataLoader(train_data,
                                  batch_size=cfg.BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=8,
                                  pin_memory=True)
    val_dataloader = DataLoader(val_data,
                                batch_size=cfg.BATCH_SIZE,
                                shuffle=False,
                                num_workers=8,
                                pin_memory=True)

    model = SiamUNet(in_ch=3)
    if cfg.RESUME:
        checkpoint = torch.load(cfg.TRAINED_LAST_MODEL)
        model.load_state_dict(checkpoint['state_dict'])
        print('resume success \n')
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        model = nn.DataParallel(model)

    if torch.cuda.is_available():
        model.cuda()

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.INIT_LEARNING_RATE,
                           weight_decay=cfg.DECAY)
    # NOTE(review): 'fl' is constructed but never used in this function;
    # kept in case construction matters elsewhere -- confirm.
    fl = FocalLoss2d(gamma=cfg.FOCAL_LOSS_GAMMA)
    scheduler = StepLR(optimizer, step_size=8, gamma=0.1)
    for epoch in range(cfg.EPOCH):
        print('epoch {}'.format(epoch + 1))
        for batch_idx, train_batch in enumerate(train_dataloader):
            model.train()
            batch_x1, batch_x2, batch_y, _, _, _ = train_batch
            batch_x1, batch_x2, batch_y = Variable(batch_x1).cuda(), Variable(
                batch_x2).cuda(), Variable(batch_y).cuda()
            outputs = model(batch_x1, batch_x2)
            del batch_x1, batch_x2
            # Deep-supervision loss over the four output heads.
            loss = calc_loss_L4(outputs[0], outputs[1], outputs[2], outputs[3],
                                batch_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Quick full validation pass every 5 training batches.
            if (batch_idx) % 5 == 0:
                model.eval()
                val_loss = 0
                # BUG FIX: validation forward passes ran with autograd
                # enabled, building and retaining graphs for every val
                # batch; no_grad avoids the memory growth.
                with torch.no_grad():
                    for v_batch_idx, val_batch in enumerate(val_dataloader):
                        v_batch_x1, v_batch_x2, v_batch_y, _, _, _ = val_batch
                        v_batch_x1, v_batch_x2, v_batch_y = Variable(
                            v_batch_x1).cuda(), Variable(
                                v_batch_x2).cuda(), Variable(v_batch_y).cuda()
                        val_outs = model(v_batch_x1, v_batch_x2)
                        del v_batch_x1, v_batch_x2
                        val_loss += float(
                            calc_loss_L4(val_outs[0], val_outs[1], val_outs[2],
                                         val_outs[3], v_batch_y))
                del val_outs, v_batch_y
                print("Train Loss: {:.6f}  Val Loss: {:.10f}".format(
                    loss, val_loss))

        # BUG FIX: scheduler.step() was called at the start of the epoch,
        # before any optimizer.step(); PyTorch >= 1.1 requires the opposite
        # order or the first LR step is skipped/shifted.
        scheduler.step()

        if (epoch + 1) % 5 == 0:
            torch.save({'state_dict': model.state_dict()},
                       os.path.join(cfg.SAVE_MODEL_PATH, cfg.TRAIN_LOSS,
                                    'model_tif_' + str(epoch + 1) + '.pth'))
    torch.save({'state_dict': model.state_dict()},
               os.path.join(cfg.SAVE_MODEL_PATH, cfg.TRAIN_LOSS,
                            'model_tif_last.pth'))
def train(params):
    """Train a Transformer translation model across multiple GPU towers.

    Builds the graph in a fresh ``tf.Graph``; if a checkpoint exists in
    ``flags_obj.model_dir`` the global step is resumed from the checkpoint
    filename suffix.  Gradients are computed per GPU tower, averaged, and
    applied with a LazyAdam optimizer under a warmup learning-rate schedule.
    Periodically the validation set is scored with BLEU; checkpoints are
    saved when BLEU improves (or is within 0.003 of the best), and training
    stops early after more than 5 evaluations without improvement.

    Args:
        params: hyperparameter container read by the dataset, the model, the
            optimizer and the learning-rate schedule (e.g. ``learning_rate``,
            ``hidden_size``, ``learning_rate_warmup_steps``,
            ``optimizer_adam_*``, ``label_smoothing``,
            ``target_vocab_size``).
    """
    with tf.Graph().as_default():
        if tf.train.latest_checkpoint(flags_obj.model_dir):
            # Checkpoints are named "...-<step>", so the step count can be
            # recovered from the filename suffix.
            global_step_value = int(
                tf.train.latest_checkpoint(flags_obj.model_dir).split("-")[-1])
            global_step = tf.Variable(initial_value=global_step_value,
                                      dtype=tf.int32,
                                      trainable=False)
            print(
                "right here!",
                int(
                    tf.train.latest_checkpoint(
                        flags_obj.model_dir).split("-")[-1]))
        else:
            # Fresh run: start the step counter at zero.
            global_step_value = 0
            global_step = tf.get_variable(
                'global_step', [],
                initializer=tf.constant_initializer(0),
                trainable=False)
        # Warmup/decay schedule driven by the (possibly resumed) global step.
        learning_rate = get_learning_rate(params.learning_rate,
                                          params.hidden_size,
                                          params.learning_rate_warmup_steps,
                                          global_step)

        # LazyAdam updates optimizer slots only for variables that actually
        # received gradients, which is cheaper for sparse embedding updates.
        optimizer = tf.contrib.opt.LazyAdamOptimizer(
            learning_rate,
            beta1=params.optimizer_adam_beta1,
            beta2=params.optimizer_adam_beta2,
            epsilon=params.optimizer_adam_epsilon)

        my_dataset = dataset.Dataset(params)

        train_iterator = my_dataset.train_input_fn(params)
        valid_iterator = my_dataset.eval_input_fn(params)

        # One gradient list per GPU tower; averaged after the loop.
        tower_grads = []
        g_model = transformer_9.Transformer(params,
                                            is_train=True,
                                            mode=None,
                                            scope="Transformer")
        # AUTO_REUSE shares the model variables across all towers.
        with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            for i in xrange(flags_obj.num_gpus):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('%s_%d' % (TOWER_NAME, i)) as scope:
                        tf.logging.info("Build graph on gpu:{}".format(i))
                        logits = g_model.inference(train_iterator.source,
                                                   train_iterator.target)
                        xentropy, weights = metrics.padded_cross_entropy_loss(
                            logits, train_iterator.target,
                            params.label_smoothing, params.target_vocab_size)
                        # Normalize the loss by the non-padding token count.
                        loss = tf.reduce_sum(xentropy) / tf.reduce_sum(weights)
                        # NOTE(review): after the loop, `summaries` and `loss`
                        # hold the values from the LAST tower only — confirm
                        # this is intentional.
                        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                      scope)
                        grads = optimizer.compute_gradients(loss)
                        tf.logging.info(
                            "total trainable variables number: {}".format(
                                len(grads)))
                        tower_grads.append(grads)
                    # Build the evaluation branch once, on the first GPU only.
                    if i == 0 and valid_iterator:
                        valid_pred = g_model.inference(
                            inputs=valid_iterator.source,
                            targets=None)["outputs"]
                        valid_tgt = valid_iterator.target
                        valid_src = valid_iterator.source

        # Average gradients across towers; single-GPU skips the averaging.
        if len(tower_grads) > 1:
            grads = average_gradients(tower_grads)
        else:
            grads = tower_grads[0]
        # NOTE(review): `summaries` is accumulated below but never merged or
        # written via the FileWriter (only the graph is) — possibly dead code.
        summaries.append(tf.summary.scalar('learning_rate', learning_rate))
        for grad, var in grads:
            if grad is not None:
                summaries.append(
                    tf.summary.histogram(var.op.name + '/gradients', grad))
        apply_gradient_op = optimizer.apply_gradients(grads,
                                                      global_step=global_step)
        for var in tf.trainable_variables():
            summaries.append(tf.summary.histogram(var.op.name, var))
        train_op = apply_gradient_op

        # Keep up to 20 checkpoints of the trainable variables.
        saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=20)

        init = tf.global_variables_initializer()
        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        sess_config.allow_soft_placement = True

        with tf.Session(config=sess_config) as sess:
            sess.run(init)
            sess.run(tf.local_variables_initializer())

            sess.run(train_iterator.initializer)

            # Restore weights when a checkpoint exists; otherwise train from
            # the freshly initialized variables.
            ckpt = tf.train.latest_checkpoint(flags_obj.model_dir)
            tf.logging.info("ckpt {}".format(ckpt))
            if ckpt and tf.train.checkpoint_exists(ckpt):
                tf.logging.info(
                    "Reloading model parameters..from {}".format(ckpt))
                saver.restore(sess, ckpt)
            else:
                tf.logging.info("create a new model...{}".format(
                    flags_obj.model_dir))
            tf.train.start_queue_runners(sess=sess)
            summary_writer = tf.summary.FileWriter(flags_obj.model_dir,
                                                   sess.graph)

            count = 0  # evaluations since the last BLEU improvement
            best_bleu = 0.0
            for step in xrange(global_step_value, flags_obj.train_steps):
                # One optimization step with dropout enabled.
                _, loss_value, lr_value = sess.run(
                    [train_op, loss, learning_rate],
                    feed_dict={g_model.dropout_rate: 0.1})
                if step % 200 == 0:
                    tf.logging.info(
                        "step: {}, loss = {:.4f}, lr = {:5f}".format(
                            step, loss_value, lr_value))

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                # Evaluate less frequently early in training.
                if step < 10000:
                    steps_between_evals = 2000
                else:
                    steps_between_evals = 1000
                if step % steps_between_evals == 0:
                    # Re-initialize the validation iterator and score the
                    # whole validation set, weighting BLEU by batch size.
                    sess.run(valid_iterator.initializer)
                    tf.logging.info(
                        "------------------ Evaluation bleu -------------------------"
                    )
                    total_bleu = 0.0
                    total_size = 0
                    while True:
                        try:
                            # Dropout disabled for evaluation.
                            val_pred, val_tgt, val_src = sess.run(
                                [valid_pred, valid_tgt, valid_src],
                                feed_dict={g_model.dropout_rate: 0.0})
                            val_bleu = metrics.compute_bleu(val_tgt, val_pred)
                            batch_size = val_pred.shape[0]
                            total_bleu += val_bleu * batch_size
                            total_size += batch_size
                        except tf.errors.OutOfRangeError:
                            # Validation set exhausted.
                            break
                    total_bleu /= total_size
                    tf.logging.info("{}, Step: {}, Valid bleu : {:.6f}".format(
                        datetime.now(), step, total_bleu))
                    tf.logging.info(
                        "--------------------- Finish evaluation ------------------------"
                    )
                    # Save the model checkpoint periodically.
                    # Zero out the untrained model's score so it is never
                    # recorded as the best.
                    if step == 0:
                        total_bleu = 0.0

                    if total_bleu > best_bleu:
                        # New best score: save and remember it.
                        best_bleu = total_bleu
                        checkpoint_path = os.path.join(flags_obj.model_dir,
                                                       'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=step)
                        tf.logging.info(
                            "Saving model at {}".format(checkpoint_path + "-" +
                                                        str(step)))
                    elif total_bleu + 0.003 > best_bleu:
                        # Within 0.003 of the best: save, but keep best_bleu.
                        checkpoint_path = os.path.join(flags_obj.model_dir,
                                                       'model.ckpt')
                        saver.save(sess, checkpoint_path, global_step=step)
                        tf.logging.info(
                            "Saving model at {}".format(checkpoint_path + "-" +
                                                        str(step)))
                    else:
                        count += 1
                        # early stop
                        # NOTE(review): `count` is never reset when BLEU
                        # improves, so 6 non-improving evaluations anywhere in
                        # the run stop training — confirm this is intended.
                        if count > 5:
                            break
            tf.logging.info("Best bleu is {}".format(best_bleu))
import sys
sys.path.append('')
from utils import dataset as ds


# Smoke test for the equivalence dataset loader: load the training split,
# print its size, inspect the first few shuffled samples, draw one batch,
# and report the vocabulary size.
d = ds.Dataset('data/split/equiv.train', 'data/vocab.txt')
print(len(d))
p = d._permutation[:10]  # NOTE: peeks at a private Dataset attribute
for sample_idx in p:
    print(d.formulae_1[sample_idx], d.formulae_2[sample_idx],
          d.labels[sample_idx])
n = d.next_batch(10)
print(n)
print(d.num_tokens)