Example #1
    def __init__(self, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
        """
        dim: feature dimension (default: 128)
        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        """
        super(MoCo, self).__init__()

        self.K = K
        self.m = m
        self.T = T

        # create the encoders: DenseNet-169 backbones whose classifier layer
        # serves as the output fc, resized below to `dim` features

        self.encoder_q = densenet169(pretrained=True)
        self.encoder_k = densenet169(pretrained=True)
        fc_features = self.encoder_q.classifier.in_features
        self.encoder_q.classifier = nn.Linear(fc_features, dim)
        self.encoder_k.classifier = nn.Linear(fc_features, dim)

        if mlp:  # hack: brute-force replacement of the fc head with a 2-layer MLP
            dim_mlp = self.encoder_q.classifier.weight.shape[1]
            self.encoder_q.classifier = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(),
                self.encoder_q.classifier)
            self.encoder_k.classifier = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(),
                self.encoder_k.classifier)

        for param_q, param_k in zip(self.encoder_q.parameters(),
                                    self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize the key encoder from the query encoder
            param_k.requires_grad = False  # updated by momentum, not by gradients

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
Example #2
def train(args):
    batch_size = args.batch_size
    epoch = args.epoch
    network = args.network
    opt = args.opt
    train_set = unpickle(args.train_path)  # avoid shadowing the train() function
    test_set = unpickle(args.test_path)
    train_data = train_set[b'data']
    test_data = test_set[b'data']

    x_train = train_data.reshape(train_data.shape[0], 3, 32, 32)
    x_train = x_train.transpose(0, 2, 3, 1)
    y_train = train_set[b'fine_labels']

    x_test = test_data.reshape(test_data.shape[0], 3, 32, 32)
    x_test = x_test.transpose(0, 2, 3, 1)
    y_test = test_set[b'fine_labels']

    x_train = norm_images(x_train)
    x_test = norm_images(x_test)

    print('-------------------------------')
    print('--train/test len: ', len(train_data), len(test_data))
    print('--x_train norm: ', compute_mean_var(x_train))
    print('--x_test norm: ', compute_mean_var(x_test))
    print('--batch_size: ', batch_size)
    print('--epoch: ', epoch)
    print('--network: ', network)
    print('--opt: ', opt)
    print('-------------------------------')

    if not os.path.exists('./trans/train.tfrecords'):
        generate_tfrecord(x_train, y_train, './trans/', 'train.tfrecords')
        generate_tfrecord(x_test, y_test, './trans/', 'test.tfrecords')

    dataset = tf.data.TFRecordDataset('./trans/train.tfrecords')
    dataset = dataset.map(parse_function)
    dataset = dataset.shuffle(buffer_size=50000)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    x_input = tf.placeholder(tf.float32, [None, 32, 32, 3])
    y_input = tf.placeholder(tf.int64, [None, ])
    y_input_one_hot = tf.one_hot(y_input, 100)
    lr = tf.placeholder(tf.float32, [])

    if network == 'resnet50':
        prob = resnet50(x_input, is_training=True, reuse=False, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnet34':
        prob = resnet34(x_input, is_training=True, reuse=False,
                        kernel_initializer=tf.contrib.layers.variance_scaling_initializer())
    elif network == 'resnet18':
        prob = resnet18(x_input, is_training=True, reuse=False,
                        kernel_initializer=tf.contrib.layers.variance_scaling_initializer())
    elif network == 'seresnet50':
        prob = se_resnet50(x_input, is_training=True, reuse=False, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnet110':
        prob = resnet110(x_input, is_training=True, reuse=False, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnet110':
        prob = se_resnet110(x_input, is_training=True, reuse=False, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnet152':
        prob = se_resnet152(x_input, is_training=True, reuse=False, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnet152':
        prob = resnet152(x_input, is_training=True, reuse=False, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnet_fixed':
        prob = get_resnet(x_input, 152, trainable=True, w_init=tf.orthogonal_initializer())
    elif network == 'densenet121':
        prob = densenet121(x_input, reuse=False, is_training=True, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet169':
        prob = densenet169(x_input, reuse=False, is_training=True, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet201':
        prob = densenet201(x_input, reuse=False, is_training=True, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet161':
        prob = densenet161(x_input, reuse=False, is_training=True, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet100bc':
        prob = densenet100bc(x_input, reuse=False, is_training=True, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet190bc':
        prob = densenet190bc(x_input, reuse=False, is_training=True, kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnext50':
        prob = resnext50(x_input, reuse=False, is_training=True, cardinality=32,
                         kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnext110':
        prob = resnext110(x_input, reuse=False, is_training=True, cardinality=32,
                          kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnext152':
        prob = resnext152(x_input, reuse=False, is_training=True, cardinality=32,
                          kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnext50':
        prob = se_resnext50(x_input, reuse=False, is_training=True, cardinality=32,
                            kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnext110':
        prob = se_resnext110(x_input, reuse=False, is_training=True, cardinality=32,
                             kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnext152':
        prob = se_resnext152(x_input, reuse=False, is_training=True, cardinality=32,
                             kernel_initializer=tf.orthogonal_initializer())

    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prob, labels=y_input_one_hot))

    # L2 weight decay (5e-4) applied to convolutional kernels only
    conv_var = [var for var in tf.trainable_variables() if 'conv' in var.name]
    l2_loss = tf.add_n([tf.nn.l2_loss(var) for var in conv_var])
    loss = loss + 5e-4 * l2_loss

    if opt == 'adam':
        opt = tf.train.AdamOptimizer(lr)
    elif opt == 'momentum':
        opt = tf.train.MomentumOptimizer(lr, 0.9)
    elif opt == 'nesterov':
        opt = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = opt.minimize(loss)

    logit_softmax = tf.nn.softmax(prob)
    acc = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logit_softmax, 1), y_input), tf.float32))

    # -------------------------------Test-----------------------------------------
    if not os.path.exists('./trans/test.tfrecords'):
        generate_tfrecord(x_test, y_test, './trans/', 'test.tfrecords')
    dataset_test = tf.data.TFRecordDataset('./trans/test.tfrecords')
    dataset_test = dataset_test.map(parse_test)
    dataset_test = dataset_test.shuffle(buffer_size=10000)
    dataset_test = dataset_test.batch(128)
    iterator_test = dataset_test.make_initializable_iterator()
    next_element_test = iterator_test.get_next()
    if network == 'resnet50':
        prob_test = resnet50(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'resnet18':
        prob_test = resnet18(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'resnet34':
        prob_test = resnet34(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'seresnet50':
        prob_test = se_resnet50(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'resnet110':
        prob_test = resnet110(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'seresnet110':
        prob_test = se_resnet110(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'seresnet152':
        prob_test = se_resnet152(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'resnet152':
        prob_test = resnet152(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'seresnet_fixed':
        prob_test = get_resnet(x_input, 152, type='se_ir', trainable=False, reuse=True)
    elif network == 'densenet121':
        prob_test = densenet121(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'densenet169':
        prob_test = densenet169(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'densenet201':
        prob_test = densenet201(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'densenet161':
        prob_test = densenet161(x_input, is_training=False, reuse=True, kernel_initializer=None)
    elif network == 'densenet100bc':
        prob_test = densenet100bc(x_input, reuse=True, is_training=False, kernel_initializer=None)
    elif network == 'densenet190bc':
        prob_test = densenet190bc(x_input, reuse=True, is_training=False, kernel_initializer=None)
    elif network == 'resnext50':
        prob_test = resnext50(x_input, is_training=False, reuse=True, cardinality=32, kernel_initializer=None)
    elif network == 'resnext110':
        prob_test = resnext110(x_input, is_training=False, reuse=True, cardinality=32, kernel_initializer=None)
    elif network == 'resnext152':
        prob_test = resnext152(x_input, is_training=False, reuse=True, cardinality=32, kernel_initializer=None)
    elif network == 'seresnext50':
        prob_test = se_resnext50(x_input, reuse=True, is_training=False, cardinality=32, kernel_initializer=None)
    elif network == 'seresnext110':
        prob_test = se_resnext110(x_input, reuse=True, is_training=False, cardinality=32, kernel_initializer=None)
    elif network == 'seresnext152':
        prob_test = se_resnext152(x_input, reuse=True, is_training=False, cardinality=32, kernel_initializer=None)

    logit_softmax_test = tf.nn.softmax(prob_test)
    # per-batch count of correct predictions; divided by the test-set size after the loop
    acc_test = tf.reduce_sum(tf.cast(tf.equal(tf.argmax(logit_softmax_test, 1), y_input), tf.float32))
    # ----------------------------------------------------------------------------
    saver = tf.train.Saver(max_to_keep=1, var_list=tf.global_variables())
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True

    now_lr = 0.001  # Warm Up
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        counter = 0
        max_test_acc = -1
        for i in range(epoch):
            sess.run(iterator.initializer)
            while True:
                try:
                    batch_train, label_train = sess.run(next_element)
                    _, loss_val, acc_val, lr_val = sess.run([train_op, loss, acc, lr],
                                                            feed_dict={x_input: batch_train, y_input: label_train,
                                                                       lr: now_lr})

                    counter += 1

                    if counter % 100 == 0:
                        print('counter: ', counter, 'loss_val', loss_val, 'acc: ', acc_val)
                    if counter % 1000 == 0:
                        print('start test ')
                        sess.run(iterator_test.initializer)
                        avg_acc = []
                        while True:
                            try:
                                batch_test, label_test = sess.run(next_element_test)
                                acc_test_val = sess.run(acc_test, feed_dict={x_input: batch_test, y_input: label_test})
                                avg_acc.append(acc_test_val)
                            except tf.errors.OutOfRangeError:
                                print('end test ', np.sum(avg_acc) / len(y_test))
                                now_test_acc = np.sum(avg_acc) / len(y_test)
                                if now_test_acc > max_test_acc:
                                    print('***** Max test changed: ', now_test_acc)
                                    max_test_acc = now_test_acc
                                    filename = 'params/distinct/' + network + '_{}.ckpt'.format(counter)
                                    # saver.save(sess, filename)  # uncomment to checkpoint the best model
                                break
                except tf.errors.OutOfRangeError:
                    print('end epoch %d/%d, lr: %f' % (i + 1, epoch, lr_val))
                    now_lr = lr_schedule(i, args.epoch)
                    break
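
Two helpers that train() relies on are not shown in the snippet. unpickle is the standard CIFAR loader from the dataset documentation; lr_schedule is not defined anywhere here, so the version below is only a hypothetical step-decay sketch consistent with the 0.001 warm-up value above:

import pickle

def unpickle(file):
    # standard CIFAR loader (from the CIFAR dataset documentation)
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')

def lr_schedule(epoch, total_epochs):
    # hypothetical step decay: raise the rate after warm-up, then drop it
    # at 50% and 75% of training (the real schedule is not in the source)
    if epoch < total_epochs * 0.5:
        return 0.1
    elif epoch < total_epochs * 0.75:
        return 0.01
    return 0.001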
Example #3
                        # (beginning of this snippet is missing; it continues a model-constructor call)
                        if_shake_shake=args.shake,
                        first_conv_stride=args.first_stride,
                        first_pool=True).to(device)
elif args.model == "resnet152":
    print("SeNet152")
    model = se_resnet152(2, if_mixup=args.mixup,
                         if_shake_shake=args.shake).to(device)
elif args.model == "densenet121":
    print("DenseNet121")
    model = densenet121(if_mixup=args.mixup,
                        if_selayer=args.se,
                        first_conv_stride=args.first_stride,
                        first_pool=True).to(device)
elif args.model == "densenet169":
    print("DenseNet169")
    model = densenet169(if_mixup=args.mixup, if_selayer=args.se).to(device)
elif args.model == "densenet201":
    print("DenseNet201")
    model = densenet201(if_mixup=args.mixup, if_selayer=args.se).to(device)
elif args.model == "dpn92":
    print("DPN92")
    model = dpn92(num_classes=2, if_selayer=args.se).to(device)
elif args.model == "dpn98":
    print("DPN98")
    model = dpn98(num_classes=2, if_selayer=args.se).to(device)
elif args.model == "dpn131":
    print("DPN131")
    model = dpn131(num_classes=2, if_selayer=args.se).to(device)

MODEL_PATH = os.path.join("model", args.directory, model_name)
checkpoint = torch.load(MODEL_PATH, map_location=device)
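
The snippet ends right after loading the checkpoint dict. What follows depends on how the checkpoint was saved; a hedged sketch, assuming the weights live under a "model" key (the actual key is not shown in the source):

model.load_state_dict(checkpoint["model"])  # hypothetical key
model.eval()  # switch to inference mode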
Example #4
def test(args):
    # train = unpickle('/data/ChuyuanXiong/up/cifar-100-python/train')
    # train_data = train[b'data']
    # x_train = train_data.reshape(train_data.shape[0], 3, 32, 32)
    # x_train = x_train.transpose(0, 2, 3, 1)

    test_set = unpickle(args.test_path)  # avoid shadowing the test() function
    test_data = test_set[b'data']

    x_test = test_data.reshape(test_data.shape[0], 3, 32, 32)
    x_test = x_test.transpose(0, 2, 3, 1)
    y_test = test_set[b'fine_labels']

    x_test = norm_images(x_test)
    # x_test = norm_images_using_mean_var(x_test, *compute_mean_var(x_train))

    network = args.network
    ckpt = args.ckpt

    x_input = tf.placeholder(tf.float32, [None, 32, 32, 3])
    y_input = tf.placeholder(tf.int64, [None, ])
    # -------------------------------Test-----------------------------------------
    if not os.path.exists('./trans/test.tfrecords'):
        generate_tfrecord(x_test, y_test, './trans/', 'test.tfrecords')
    dataset_test = tf.data.TFRecordDataset('./trans/test.tfrecords')
    dataset_test = dataset_test.map(parse_test)
    dataset_test = dataset_test.shuffle(buffer_size=10000)
    dataset_test = dataset_test.batch(128)
    iterator_test = dataset_test.make_initializable_iterator()
    next_element_test = iterator_test.get_next()
    if network == 'resnet50':
        prob_test = resnet50(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'resnet18':
        prob_test = resnet18(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'resnet34':
        prob_test = resnet34(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'seresnet50':
        prob_test = se_resnet50(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'resnet110':
        prob_test = resnet110(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'seresnet110':
        prob_test = se_resnet110(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'seresnet152':
        prob_test = se_resnet152(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'resnet152':
        prob_test = resnet152(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'seresnet_fixed':
        prob_test = get_resnet(x_input, 152, type='se_ir', trainable=False, reuse=False)
    elif network == 'densenet121':
        prob_test = densenet121(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'densenet169':
        prob_test = densenet169(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'densenet201':
        prob_test = densenet201(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'densenet161':
        prob_test = densenet161(x_input, is_training=False, reuse=False, kernel_initializer=None)
    elif network == 'densenet100bc':
        prob_test = densenet100bc(x_input, reuse=False, is_training=False, kernel_initializer=None)
    elif network == 'densenet190bc':
        prob_test = densenet190bc(x_input, reuse=False, is_training=False, kernel_initializer=None)
    elif network == 'resnext50':
        prob_test = resnext50(x_input, is_training=False, reuse=False, cardinality=32, kernel_initializer=None)
    elif network == 'resnext110':
        prob_test = resnext110(x_input, is_training=False, reuse=False, cardinality=32, kernel_initializer=None)
    elif network == 'resnext152':
        prob_test = resnext152(x_input, is_training=False, reuse=False, cardinality=32, kernel_initializer=None)
    elif network == 'seresnext50':
        prob_test = se_resnext50(x_input, reuse=False, is_training=False, cardinality=32, kernel_initializer=None)
    elif network == 'seresnext110':
        prob_test = se_resnext110(x_input, reuse=False, is_training=False, cardinality=32, kernel_initializer=None)
    elif network == 'seresnext152':
        prob_test = se_resnext152(x_input, reuse=False, is_training=False, cardinality=32, kernel_initializer=None)

    # prob_test = tf.layers.dense(prob_test, 100, reuse=True, name='before_softmax')
    logit_softmax_test = tf.nn.softmax(prob_test)
    acc_test = tf.reduce_sum(tf.cast(tf.equal(tf.argmax(logit_softmax_test, 1), y_input), tf.float32))

    # the saver must also restore BN moving statistics, which are not trainable variables
    var_list = tf.trainable_variables()
    g_list = tf.global_variables()
    bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
    bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
    var_list += bn_moving_vars

    saver = tf.train.Saver(var_list=var_list)
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        saver.restore(sess, ckpt)
        sess.run(iterator_test.initializer)
        avg_acc = []
        while True:
            try:
                batch_test, label_test = sess.run(next_element_test)
                acc_test_val = sess.run(acc_test, feed_dict={x_input: batch_test, y_input: label_test})
                avg_acc.append(acc_test_val)
            except tf.errors.OutOfRangeError:
                print('end test ', np.sum(avg_acc) / len(y_test))
                break
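
norm_images is used by both train() and test() but is not defined in these snippets. Given the compute_mean_var check printed during training, a plausible sketch is per-channel standardization; this is an assumption, not the confirmed implementation:

import numpy as np

def norm_images(images):
    # hypothetical: standardize each RGB channel to zero mean, unit variance
    images = images.astype(np.float32)
    mean = images.mean(axis=(0, 1, 2), keepdims=True)
    std = images.std(axis=(0, 1, 2), keepdims=True)
    return (images - mean) / std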
Example #5
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', default=50, type=int, help='epoch number')
    parser.add_argument('-b',
                        '--batch_size',
                        default=256,
                        type=int,
                        help='mini-batch size')
    parser.add_argument('--lr',
                        '--learning_rate',
                        default=1e-3,
                        type=float,
                        help='initial learning rate')
    parser.add_argument('-c',
                        '--continue',
                        dest='continue_path',
                        type=str,
                        required=False)
    parser.add_argument('--exp_name',
                        default=config.exp_name,
                        type=str,
                        required=False)
    parser.add_argument('--drop_rate', default=0, type=float, required=False)
    parser.add_argument('--only_fc',
                        action='store_true',
                        help='only train fc layers')
    parser.add_argument('--net',
                        default='densenet169',
                        type=str,
                        required=False)
    parser.add_argument('--local',
                        action='store_true',
                        help='train local branch')
    args = parser.parse_args()
    # note: these hard-coded overrides take precedence over any command-line values
    args.batch_size = 32
    args.epochs = 150
    args.net = 'densenet169'
    print(args)

    config.exp_name = args.exp_name
    config.make_dir()
    save_args(args, config.log_dir)

    # get network
    if args.net == 'resnet50':
        net = resnet50(pretrained=True, drop_rate=args.drop_rate)
    elif args.net == 'resnet101':
        net = resnet101(pretrained=True, drop_rate=args.drop_rate)
    elif args.net == 'densenet121':
        net = models.densenet121(pretrained=True)
        net.classifier = nn.Sequential(nn.Linear(1024, 1), nn.Sigmoid())
    elif args.net == 'densenet169':
        net = densenet169(pretrained=True, drop_rate=args.drop_rate)
    elif args.net == 'fusenet':
        global_branch = torch.load(GLOBAL_BRANCH_DIR)['net']
        local_branch = torch.load(LOCAL_BRANCH_DIR)['net']
        net = fusenet(global_branch, local_branch)
        del global_branch, local_branch
    else:
        raise NameError('unknown network: {}'.format(args.net))

    net = net.cuda()
    sess = Session(config, net=net)

    # get dataloader
    # train_loader = get_dataloaders('train', batch_size=args.batch_size,
    #                                shuffle=True, is_local=args.local)
    #
    # valid_loader = get_dataloaders('valid', batch_size=args.batch_size,
    #                                shuffle=False, is_local=args.local)
    train_loader = get_dataloaders('train',
                                   batch_size=args.batch_size,
                                   num_workers=4,
                                   shuffle=True)

    valid_loader = get_dataloaders('valid',
                                   batch_size=args.batch_size,
                                   shuffle=False)

    if args.continue_path and os.path.exists(args.continue_path):
        sess.load_checkpoint(args.continue_path)

    # start session
    clock = sess.clock
    tb_writer = sess.tb_writer
    sess.save_checkpoint('start.pth.tar')

    # set criterion, optimizer and scheduler
    criterion = nn.BCELoss().cuda()  # not used

    if args.only_fc:
        optimizer = optim.Adam(sess.net.module.classifier.parameters(),
                               args.lr)
    else:
        optimizer = optim.Adam(sess.net.parameters(), args.lr)

    scheduler = ReduceLROnPlateau(optimizer,
                                  'max',
                                  factor=0.1,
                                  patience=10,
                                  verbose=True)

    # start training
    for e in range(args.epochs):
        train_out = train_model(train_loader, sess.net, criterion, optimizer,
                                clock.epoch)
        valid_out = valid_model(valid_loader, sess.net, criterion, optimizer,
                                clock.epoch)

        tb_writer.add_scalars('loss', {
            'train': train_out['epoch_loss'],
            'valid': valid_out['epoch_loss']
        }, clock.epoch)

        tb_writer.add_scalars('acc', {
            'train': train_out['epoch_acc'],
            'valid': valid_out['epoch_acc']
        }, clock.epoch)

        tb_writer.add_scalar('auc', valid_out['epoch_auc'], clock.epoch)

        tb_writer.add_scalar('learning_rate', optimizer.param_groups[-1]['lr'],
                             clock.epoch)
        scheduler.step(valid_out['epoch_auc'])

        if valid_out['epoch_auc'] > sess.best_val_acc:
            # best_val_acc tracks the best validation AUC, despite its name
            sess.best_val_acc = valid_out['epoch_auc']
            sess.save_checkpoint('best_model.pth.tar')

        if clock.epoch % 10 == 0:
            sess.save_checkpoint('epoch{}.pth.tar'.format(clock.epoch))
        sess.save_checkpoint('latest.pth.tar')

        clock.tock()
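
One detail worth noting about --only_fc: handing only the classifier parameters to Adam leaves the backbone weights unchanged, but autograd still computes gradients for them on every step. A minimal sketch that also disables those gradients, assuming the head is named classifier as in the DenseNet branches above:

# optional: skip gradient computation for the frozen backbone
for name, param in net.named_parameters():
    if not name.startswith('classifier'):
        param.requires_grad = False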