Example #1
def main():
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
    parser.add_argument('--lr',
                        default=0.001,
                        type=float,
                        help='learning rate')
    parser.add_argument('--resume',
                        '-r',
                        action='store_true',
                        help='resume from checkpoint')
    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    present_time = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
    input_type = "only_one_eye"

    print("Prepare directory...")
    prepare_dir(present_time, input_type)

    print("Prepare model...")
    net = EfficientNetB0().to(device)

    best_acc, start_epoch = 0, 0
    if args.resume:
        print("Loading checkpoint...")
        net, best_acc, start_epoch = load_parameter(net, present_time)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=1e-4)
    # optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

    train_img_list = glob.glob(f"./../eye_data/{input_type}/train/*")
    test_img_list = glob.glob(f"./../eye_data/{input_type}/test/*")
    random.shuffle(train_img_list)

    torch.save(
        net.state_dict(), f"./weights/{present_time}/init_weight.pth"
    )  # save the initial weights first; each cross-validation fold must reload them
    total_best_valid_acc, total_best_valid_loss = 0, 0  # track accuracy and loss across the cross-validation folds

    # Training: use k-fold cross validation to test the generalizability
    k = 10
    k_fold_cross_validation(net, optimizer, criterion, train_img_list, k,
                            present_time)

    # Training: use the entire training set to train the model
    net.load_state_dict(
        torch.load(f"./weights/{present_time}/init_weight.pth"))
    random.shuffle(train_img_list)
    net = final_training(net, optimizer, criterion, train_img_list,
                         present_time)

    # Testing: Evaluate the trained model on the test set
    print("\n========= result on testing set =========\n")
    print("testing set length info:", len(test_img_list))
    final_testing(net, criterion, test_img_list)
    store_parameter(120, net, optimizer, -1, -1, present_time)

    CAM(net, test_img_list, present_time, input_type)
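Example #1 above leans on store_parameter and load_parameter for checkpointing. A minimal sketch of what such helpers could look like with torch.save / torch.load is shown below; the file layout and dictionary keys are assumptions, not the project's actual implementation.

# Hypothetical sketch only -- the real store_parameter/load_parameter may differ.
import os
import torch


def store_parameter(epoch, net, optimizer, valid_acc, valid_loss, present_time):
    # Bundle everything needed to resume training into a single checkpoint file.
    state = {
        'epoch': epoch,
        'net': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'valid_acc': valid_acc,
        'valid_loss': valid_loss,
    }
    os.makedirs(f"./weights/{present_time}", exist_ok=True)
    torch.save(state, f"./weights/{present_time}/checkpoint.pth")


def load_parameter(net, present_time):
    # Restore the weights plus the bookkeeping values that main() unpacks.
    state = torch.load(f"./weights/{present_time}/checkpoint.pth")
    net.load_state_dict(state['net'])
    return net, state['valid_acc'], state['epoch']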
Example #2
def plot_conv_output(conv_img, name, filters_all=True, filters=[0]):
    """
    Makes plots of results of performing convolution
    :param conv_img: numpy array of rank 4
    :param name: string, name of convolutional layer
    :return: nothing, plots are saved on the disk
    """
    # make path to output folder
    plot_dir = os.path.join(PLOT_DIR, 'conv_output')
    plot_dir = os.path.join(plot_dir, name)

    # create directory if does not exist, otherwise empty it
    utils.prepare_dir(plot_dir, empty=True)

    w_min = np.min(conv_img)
    w_max = np.max(conv_img)

    # get number of convolutional filters
    if filters_all:
        num_filters = conv_img.shape[3]
        filters = range(conv_img.shape[3])
    else:
        num_filters = len(filters)

    # get number of grid rows and columns
    grid_r, grid_c = utils.get_grid_dim(num_filters)

    # create figure and axes
    fig, axes = plt.subplots(min([grid_r, grid_c]), max([grid_r, grid_c]))

    if num_filters == 1:
        img = conv_img[0, :, :, filters[0]]
        axes.imshow(img,
                    vmin=w_min,
                    vmax=w_max,
                    interpolation='bicubic',
                    cmap='Greys')
        # remove any labels from the axes
        axes.set_xticks([])
        axes.set_yticks([])

    # iterate filters
    else:
        for l, ax in enumerate(axes.flat):
            # get a single image
            img = conv_img[0, :, :, filters[l]]
            # put it on the grid
            ax.imshow(img,
                      vmin=w_min,
                      vmax=w_max,
                      interpolation='bicubic',
                      cmap='Greys')
            # remove any labels from the axes
            ax.set_xticks([])
            ax.set_yticks([])
    # save figure
    plt.savefig(os.path.join(plot_dir, '{}.png'.format(name)),
                bbox_inches='tight')
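The comments above spell out what the two utils helpers are expected to do: prepare_dir creates the output directory (or empties it), and get_grid_dim picks the subplot grid. A minimal sketch of such helpers, assuming they are not the project's actual implementations:

# Hypothetical sketch of the utils helpers used throughout these plotting examples;
# the project's own prepare_dir/get_grid_dim may behave differently.
import os
import shutil


def prepare_dir(path, empty=False):
    # Create the directory if it does not exist; optionally empty it first.
    if empty and os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)
    return path


def get_grid_dim(n):
    # Return the most square (rows, cols) grid that holds n subplots.
    rows = int(n ** 0.5)
    while rows > 1 and n % rows != 0:
        rows -= 1
    return rows, n // rows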
Example #3
def plot_conv_weights(weights, name, channels_all=True):
    """
    Plots convolutional filters

    :param weights: numpy array of rank 4

    :param name: string, name of convolutional layer

    :param channels_all: boolean, optional

    :return: nothing, plots are saved on the disk

    """
    # make path to output folder
    plot_dir = os.path.join(PLOT_DIR, 'conv_weights')
    plot_dir = os.path.join(plot_dir, name)
    # create directory if does not exist, otherwise empty it
    utils.prepare_dir(plot_dir, empty=True)
    w_min = np.min(weights)
    w_max = np.max(weights)
    channels = [0]
    # make a list of channels if all are plotted
    if channels_all:
        channels = range(weights.shape[2])

    # get number of convolutional filters
    num_filters = weights.shape[3]
    # get number of grid rows and columns
    grid_r, grid_c = utils.get_grid_dim(num_filters)
    # create figure and axes
    fig, axes = plt.subplots(min([grid_r, grid_c]), max([grid_r, grid_c]))

    # iterate channels
    for channel in channels:
        # iterate filters inside every channel
        for l, ax in enumerate(axes.flat):
            # get a single filter
            img = weights[:, :, channel, l]
            # put it on the grid
            ax.imshow(img,
                      vmin=w_min,
                      vmax=w_max,
                      interpolation='nearest',
                      cmap='seismic')
            # remove any labels from the axes
            ax.set_xticks([])
            ax.set_yticks([])

        # save figure
        plt.savefig(os.path.join(plot_dir, '{}-{}.png'.format(name, channel)),
                    bbox_inches='tight')
Example #4
def update_rois(self):

    data = self.train_data.copy()
    data.update(self.test_data)

    batch_size = 1
    roi_dir = prepare_dir(self.roi_dir)
    cam_dir = prepare_dir('./cams')
    print(roi_dir)
    train_file = './TrainImages.label'
    #train_file = './SUNTrainImages.txt'
    test_file = './TestImages.label'
    #test_file = './SUNTestImages.txt'
    train_iter, train_data = load_data(self.img_dir, roi_dir, '', train_file, batch_size, False, cam_dir=cam_dir)
    test_iter, test_data = load_data(self.img_dir, roi_dir, '', test_file, batch_size, False, cam_dir=cam_dir)
    x = tf.placeholder(tf.float32, [None, imgsize, imgsize, 3])
    cate_id = tf.placeholder(tf.int32)
    fdict = {'x': x, 'cate_id': cate_id}
    with tf.variable_scope("graph_model", reuse=tf.AUTO_REUSE):
        model_out = self.run_cnn(x)
    # Configuration of GPU usage
    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.7
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        # extract salient regions
        sess.run(train_iter.initializer)
        count = 0
        vstart_time = time.time()
        while True:
            try:
                count += 1
                batch_input = sess.run(train_data)
                self.extract_salient_region_scda_adapt(sess, model_out, fdict, batch_input)
                if count % 1000 == 0:
                    print('{} images are done, {:.4f}s per image'.format(
                        count, (time.time() - vstart_time) / count))
            except tf.errors.OutOfRangeError:
                break
        # extract test feature
        sess.run(test_iter.initializer)
        count = 0
        vstart_time = time.time()
        while True:
            try:
                count += 1
                batch_input = sess.run(test_data)
                self.extract_salient_region_scda_adapt(sess, model_out, fdict, batch_input)
                if count % 1000 == 0:
                    print('{} images are done, {:.4f}s per image'.format(
                        count, (time.time() - vstart_time) / count))
            except tf.errors.OutOfRangeError:
                break
Example #5
def plot_conv_weights(weights, name, channels_all=True):
    """
    Plots convolutional filters
    :param weights: numpy array of rank 4
    :param name: string, name of convolutional layer
    :param channels_all: boolean, optional
    :return: nothing, plots are saved on the disk
    """
    # make path to output folder
    plot_dir = os.path.join(PLOT_DIR, 'conv_weights')
    plot_dir = os.path.join(plot_dir, name)

    # create directory if does not exist, otherwise empty it
    utils.prepare_dir(plot_dir, empty=True)

    w_min = np.min(weights)
    w_max = np.max(weights)

    channels = [0]
    # make a list of channels if all are plotted
    if channels_all:
        channels = range(weights.shape[2])

    # get number of convolutional filters
    num_filters = weights.shape[3]

    # get number of grid rows and columns
    grid_r, grid_c = utils.get_grid_dim(num_filters)

    # create figure and axes
    fig, axes = plt.subplots(min([grid_r, grid_c]),
                             max([grid_r, grid_c]))

    # iterate channels
    for channel in channels:
        # iterate filters inside every channel
        for l, ax in enumerate(axes.flat):
            # get a single filter
            img = weights[:, :, channel, l]
            # put it on the grid
            ax.imshow(img, vmin=w_min, vmax=w_max, interpolation='nearest', cmap='seismic')
            # remove any labels from the axes
            ax.set_xticks([])
            ax.set_yticks([])
        # save figure
        plt.savefig(os.path.join(plot_dir, '{}-{}.png'.format(name, channel)), bbox_inches='tight')
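A hedged usage sketch for plot_conv_weights: the array shape follows how the function indexes weights[:, :, channel, l], but the random tensor and layer names are placeholders rather than real model weights.

# Usage sketch with synthetic data; a real call would pass the weights of an
# actual convolutional layer (e.g. pulled from a checkpoint) instead.
import numpy as np

weights = np.random.randn(5, 5, 3, 16).astype(np.float32)  # (kh, kw, in_channels, filters)
plot_conv_weights(weights, name='conv1', channels_all=False)  # channel 0 only
plot_conv_weights(weights, name='conv1_all')                  # one figure per input channel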
Example #6
def main(args):
    print("Begin data preparation.")
    s_train, t_train, s_test, t_test = prepare_data(args)
    print("Finish data preparation.")
    print("Begin building models.")
    encoder = office.Encoder()
    chainer.serializers.load_npz('alexnet.npz', encoder, strict=False)
    bottleneck = office.Bottleneck()
    classifier = office.Classifier()
    do_classifier = office.DomainClassifier()
    print("Finish building models.")

    encoder_opt = setup_optimizer(encoder, args.optimizer, args.lr)
    bottleneck_opt = setup_optimizer(bottleneck, args.optimizer, 10 * args.lr)
    classifier_opt = setup_optimizer(classifier, args.optimizer, 10 * args.lr)
    do_classifier_opt = setup_optimizer(do_classifier, args.optimizer,
                                        10 * args.lr)
    optimizers = {
        'encoder': encoder_opt,
        'domain_classifier': do_classifier_opt,
        'bottleneck': bottleneck_opt,
        'classifier': classifier_opt
    }
    loss_list = ['loss_cla_s', 'loss_cla_t', 'loss_do_cla']
    target_model = LossAndAccuracy(encoder, classifier, bottleneck)

    updater = Updater(s_train, t_train, optimizers, args)
    out_dir = utils.prepare_dir(args)
    trainer = training.Trainer(updater, (args.max_iter, 'iteration'),
                               out=out_dir)
    trainer.extend(extensions.LogReport(trigger=(args.interval, args.unit)))
    for name, opt in optimizers.items():
        trainer.extend(extensions.snapshot_object(opt.target, filename=name),
                       trigger=MaxValueTrigger('acc_t',
                                               (args.interval, args.unit)))
    trainer.extend(extensions.Evaluator(t_test,
                                        target_model,
                                        device=args.device),
                   trigger=(args.interval, args.unit))
    trainer.extend(
        extensions.PrintReport(
            [args.unit, *loss_list, 'acc_s', 'acc_t', 'elapsed_time']))
    trainer.extend(
        extensions.PlotReport([*loss_list],
                              x_key=args.unit,
                              file_name='loss.png',
                              trigger=(args.interval, args.unit)))
    trainer.extend(
        extensions.PlotReport(['acc_s', 'acc_t'],
                              x_key=args.unit,
                              file_name='accuracy.png',
                              trigger=(args.interval, args.unit)))
    trainer.extend(extensions.ProgressBar(update_interval=1))
    print("Start training loops.")
    trainer.run()
    print("Finish training loops.")
Example #7
def save_model(encoder, decoder, plot_losses, model_name):
    stamp = str(time.time())
    savepath = utils.prepare_dir(model_name, stamp)
    torch.save(encoder.state_dict(), savepath + "/%s.encoder" % stamp)
    torch.save(decoder.state_dict(), savepath + "/%s.decoder" % stamp)
    try:
        utils.showPlot(plot_losses, model_name, stamp)
    except Exception:
        pass  # plotting is best-effort; do not block saving on it
    print(" * model saved with time stamp:", stamp)
Example #8
def plot_conv_output(conv_img, name):
    """
    Makes plots of results of performing convolution
    :param conv_img: numpy array of rank 4
    :param name: string, name of convolutional layer
    :return: nothing, plots are saved on the disk
    """
    # make path to output folder
    plot_dir = os.path.join(PLOT_DIR, 'feature_maps')
    plot_dir = os.path.join(plot_dir, name)
    plt.figure(num=None, figsize=(16, 12), dpi=80)
    # create directory if does not exist, otherwise empty it
    utils.prepare_dir(plot_dir, empty=True)

    #    w_min = np.min(conv_img)
    #    w_max = np.max(conv_img)

    # get number of convolutional filters
    num_filters = conv_img.shape[1]

    # get number of grid rows and columns
    grid_r, grid_c = utils.get_grid_dim(num_filters)

    # create figure and axes
    fig, axes = plt.subplots(min([grid_r, grid_c]),
                             max([grid_r, grid_c]),
                             figsize=(12, 8))

    # iterate filters
    for l, ax in enumerate(axes.flat):
        # get a single image
        img = conv_img[0, l, :, :]
        # put it on the grid
        ax.imshow(img, interpolation='bicubic', cmap='Greys')
        # remove any labels from the axes
        ax.set_xticks([])
        ax.set_yticks([])
    # save figure
    plt.savefig(os.path.join(plot_dir, '{}.png'.format(name)),
                bbox_inches='tight')
Example #9
def main(args):
    print("Begin data preparation.")
    s_train, t_train, s_test, t_test = prepare_data(args)
    print("Finish data preparation.")
    print("Begin building models.")
    pixel_mean = mean_dict[args.source + '_' + args.target].astype('f')
    encoder = mnistm.Encoder(pixel_mean)
    classifier = mnistm.Classifier()
    do_classifier = mnistm.DomainClassifier()
    print("Finish building models.")

    encoder_opt = setup_optimizer(encoder, args.optimizer, args.lr)
    classifier_opt = setup_optimizer(classifier, args.optimizer, args.lr)
    do_classifier_opt = setup_optimizer(do_classifier, args.optimizer, args.lr)
    optimizers = {
        'encoder': encoder_opt,
        'domain_classifier': do_classifier_opt,
        'classifier': classifier_opt
    }
    loss_list = ['loss_cla_s', 'loss_cla_t', 'loss_do_cla']
    target_model = LossAndAccuracy(encoder, classifier)

    updater = Updater(s_train, t_train, optimizers, args)
    out_dir = utils.prepare_dir(args)
    trainer = training.Trainer(updater, (args.max_iter, 'iteration'),
                               out=out_dir)
    trainer.extend(extensions.LogReport(trigger=(args.interval, args.unit)))
    for name, opt in optimizers.items():
        trainer.extend(extensions.snapshot_object(opt.target, filename=name),
                       trigger=MaxValueTrigger('acc_t',
                                               (args.interval, args.unit)))
    trainer.extend(extensions.Evaluator(t_test,
                                        target_model,
                                        device=args.device),
                   trigger=(args.interval, args.unit))
    trainer.extend(
        extensions.PrintReport(
            [args.unit, *loss_list, 'acc_s', 'acc_t', 'elapsed_time']))
    trainer.extend(
        extensions.PlotReport([*loss_list],
                              x_key=args.unit,
                              file_name='loss.png',
                              trigger=(args.interval, args.unit)))
    trainer.extend(
        extensions.PlotReport(['acc_s', 'acc_t'],
                              x_key=args.unit,
                              file_name='accuracy.png',
                              trigger=(args.interval, args.unit)))
    trainer.extend(extensions.ProgressBar(update_interval=1))
    print("Start training loops.")
    trainer.run()
    print("Finish training loops.")
Example #10
def plot_conv_output(conv_img, name):
    """
    Makes plots of results of performing convolution
    :param conv_img: numpy array of rank 4
    :param name: string, name of convolutional layer
    :return: nothing, plots are saved on the disk
    """
    # make path to output folder
    plot_dir = os.path.join(PLOT_DIR, 'conv_output')
    plot_dir = os.path.join(plot_dir, name)

    # create directory if does not exist, otherwise empty it
    utils.prepare_dir(plot_dir, empty=True)

    w_min = np.min(conv_img)
    w_max = np.max(conv_img)

    # get number of convolutional filters
    num_filters = conv_img.shape[3]

    # get number of grid rows and columns
    grid_r, grid_c = utils.get_grid_dim(num_filters)

    # create figure and axes
    fig, axes = plt.subplots(min([grid_r, grid_c]),
                             max([grid_r, grid_c]))

    # iterate filters
    for l, ax in enumerate(axes.flat):
        # get a single image
        img = conv_img[0, :, :, l]
        # put it on the grid
        ax.imshow(img, vmin=w_min, vmax=w_max, interpolation='bicubic', cmap='Greys')
        # remove any labels from the axes
        ax.set_xticks([])
        ax.set_yticks([])
    # save figure
    plt.savefig(os.path.join(plot_dir, '{}.png'.format(name)), bbox_inches='tight')
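A hedged usage sketch for this channels-last variant of plot_conv_output; the activations below are random placeholders, whereas in practice they would come from running an image through the network.

# Usage sketch with synthetic activations shaped (batch, height, width, num_filters).
import numpy as np

conv_img = np.random.rand(1, 28, 28, 32).astype(np.float32)
plot_conv_output(conv_img, name='conv1')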
Example #11
def main(args):
    s_train, s_test = dataset.load_svhn()
    t_train, t_test = dataset.load_mnist()

    s_train_iter = SerialIterator(
        s_train, args.batchsize, shuffle=True, repeat=True)
    t_train_iter = SerialIterator(
        t_train, args.batchsize, shuffle=True, repeat=True)
    s_test_iter = SerialIterator(
        s_test, args.batchsize, shuffle=False, repeat=False)
    t_test_iter = SerialIterator(
        t_test, args.batchsize, shuffle=False, repeat=False)

    model = drcn.DRCN()
    target_model = LossAndAccuracy(model)
    loss_list = ['loss_cla_s', 'loss_cla_t', 'loss_rec']
    optimizer = chainer.optimizers.RMSprop(args.lr)
    optimizer.setup(model)
    optimizers = {
        'model': optimizer
    }

    updater = Updater(s_train_iter, t_train_iter, optimizers, args)
    out_dir = utils.prepare_dir(args)
    trainer = Trainer(updater, (args.max_iter, 'iteration'), out=out_dir)
    trainer.extend(extensions.LogReport(trigger=(args.interval, args.unit)))
    trainer.extend(
        extensions.snapshot_object(model, filename='model'),
        trigger=MaxValueTrigger('acc_t', (args.interval, args.unit)))
    trainer.extend(extensions.Evaluator(t_test_iter, target_model,
                                        device=args.device), trigger=(args.interval, args.unit))
    trainer.extend(extensions.PrintReport([args.unit, *loss_list, 'acc_s', 'acc_t', 'elapsed_time']))
    trainer.extend(extensions.PlotReport([*loss_list], x_key=args.unit, file_name='loss.png', trigger=(args.interval, args.unit)))
    trainer.extend(extensions.PlotReport(['acc_s', 'acc_t'], x_key=args.unit, file_name='accuracy.png', trigger=(args.interval, args.unit)))
    trainer.extend(extensions.ProgressBar(update_interval=1))
    trainer.run()
Example #12
def train(cli_params):
    cli_params['save_dir'] = prepare_dir(cli_params['save_to'])
    logfile = os.path.join(cli_params['save_dir'], 'log.txt')

    # Log also DEBUG to a file
    fh = logging.FileHandler(filename=logfile)
    fh.setLevel(logging.DEBUG)
    logger.addHandler(fh)

    logger.info('Logging into %s' % logfile)

    p, loaded = load_and_log_params(cli_params)
    in_dim, data = setup_data(p, test_set=False)
    if not loaded:
        # Set the zero layer to match input dimensions
        p.encoder_layers = (in_dim, ) + p.encoder_layers

    ladder = setup_model(p)

    # Training
    all_params = ComputationGraph([ladder.costs.total]).parameters
    logger.info('Found the following parameters: %s' % str(all_params))

    # Fetch all batch normalization updates. They are in the clean path.
    bn_updates = ComputationGraph([ladder.costs.class_clean]).updates
    assert 'counter' in [u.name for u in list(bn_updates.keys())], \
        'No batch norm params in graph - the graph has been cut?'

    training_algorithm = GradientDescent(
        cost=ladder.costs.total,
        params=all_params,
        step_rule=Adam(learning_rate=ladder.lr))
    # In addition to actual training, also do BN variable approximations
    training_algorithm.add_updates(bn_updates)

    short_prints = {
        "train": {
            'T_C_class': ladder.costs.class_corr,
            'T_C_de': list(ladder.costs.denois.values()),
        },
        "valid_approx":
        OrderedDict([
            ('V_C_class', ladder.costs.class_clean),
            ('V_E', ladder.error.clean),
            ('V_C_de', list(ladder.costs.denois.values())),
        ]),
        "valid_final":
        OrderedDict([
            ('VF_C_class', ladder.costs.class_clean),
            ('VF_E', ladder.error.clean),
            ('VF_C_de', list(ladder.costs.denois.values())),
        ]),
    }

    main_loop = MainLoop(
        training_algorithm,
        # Datastream used for training
        make_datastream(data.train,
                        data.train_ind,
                        p.batch_size,
                        n_labeled=p.labeled_samples,
                        n_unlabeled=p.unlabeled_samples),
        model=Model(theano.tensor.cast(ladder.costs.total, "float32")),
        extensions=[
            FinishAfter(after_n_epochs=p.num_epochs),

            # This will estimate the validation error using
            # running average estimates of the batch normalization
            # parameters, mean and variance
            ApproxTestMonitoring(
                [ladder.costs.class_clean, ladder.error.clean] +
                list(ladder.costs.denois.values()),
                make_datastream(data.valid,
                                data.valid_ind,
                                p.valid_batch_size,
                                scheme=ShuffledScheme),
                prefix="valid_approx"),
            TrainingDataMonitoring([
                ladder.costs.total, ladder.costs.class_corr,
                training_algorithm.total_gradient_norm
            ] + list(ladder.costs.denois.values()),
                                   prefix="train",
                                   after_epoch=True),
            SaveParams(None, all_params, p.save_dir, after_epoch=True),
            SaveExpParams(p, p.save_dir, before_training=True),
            ShortPrinting(short_prints),
            LRDecay(ladder.lr,
                    p.num_epochs * p.lrate_decay,
                    p.num_epochs,
                    after_epoch=True),
        ])
    main_loop.run()

    # Get results
    df = main_loop.log.to_dataframe()
    col = 'valid_final_error_rate_clean'
    logger.info('%s %g' % (col, df[col].iloc[-1]))

    if main_loop.log.status['epoch_interrupt_received']:
        return None
    return df
Example #13
def train():
    if FLAGS.load_model is not None:
        checkpoints_dir = "checkpoints/" + FLAGS.load_model.lstrip("checkpoints/")
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir)
        except os.error:
            pass

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            # X_train_file=FLAGS.X,
            # Y_train_file=FLAGS.Y,
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf
        )
        # G_loss, D_Y_loss, F_loss, D_X_loss, teacher_loss, student_loss, learning_loss, x_correct, y_correct, fake_y_correct, fake_y_pre
        G_loss, D_Y_loss, F_loss, D_X_loss, teacher_loss, student_loss, learning_loss, \
        x_correct, y_correct, fake_x_correct, softmax3, fake_x_pre, f_fakeX, fake_x, fake_y = cycle_gan.model()

        # G_loss, D_Y_loss, F_loss, D_X_loss, teacher_loss, student_loss, learning_loss,
        # x_correct, y_correct, fake_y_correct, softmax3, fake_y_pre, f_fakeX, fake_x, fake_y= cycle_gan.model()

        # softmax3,fake_y_pre,f_fakeX,fake_x,fake_y= cycle_gan.model()
        # optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss, teacher_loss, student_loss, learning_loss)
        # summary_op = tf.summary.merge_all()
        # train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = "checkpoints/20190611-1650/model.ckpt-30000.meta"
            print('meta_graph_path', meta_graph_path)
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, "checkpoints/20190611-1650/model.ckpt-30000")

            step = 0
            # meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            # restore = tf.train.import_meta_graph(meta_graph_path)
            # restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            # step = int(meta_graph_path.split("-")[2].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            while not coord.should_stop():
                result_dir = './result'
                fake_dir = os.path.join(result_dir, 'fake_xy')
                roc_dir = os.path.join(result_dir, 'roc_curves')
                plot_dir = os.path.join(result_dir, 'tsne_pca')
                conv_dir = os.path.join(result_dir, 'convs')
                occ_dir = os.path.join(result_dir, 'occ_test')
                Xconv_dir = os.path.join(result_dir, 'Xconv_dir')
                fakeXconv_dir = os.path.join(result_dir, 'fakeXconv_dir')
                Y_VGGconv_dir = os.path.join(result_dir, 'Y_VGGconv_dir')
                fakeY_VGGconv_dir = os.path.join(result_dir, 'fakeY_VGGconv_dir')

                rconv_dir = os.path.join(result_dir, 'resconvs')
                utils.prepare_dir(result_dir)
                utils.prepare_dir(occ_dir)
                utils.prepare_dir(fake_dir)
                utils.prepare_dir(roc_dir)
                utils.prepare_dir(plot_dir)
                utils.prepare_dir(conv_dir)
                utils.prepare_dir(rconv_dir)
                utils.prepare_dir(Xconv_dir)
                utils.prepare_dir(fakeXconv_dir)
                utils.prepare_dir(Y_VGGconv_dir)
                utils.prepare_dir(fakeY_VGGconv_dir)

                x_image, x_label, oximage = get_test_batch2("X", 1, FLAGS.image_size, FLAGS.image_size, "./dataset/")
                y_image, y_label, oyimage = get_test_batch2("Y", 1, FLAGS.image_size, FLAGS.image_size, "./dataset/")

                image = y_image[1]
                width = height = 256

                occluded_size = 16
                print('1----------------')
                # data  = NP()
                data = np.empty((width * height + 1, width, height, 3), dtype="float32")
                print('2----------------')
                print('data  ---')
                data[0, :, :, :] = image
                cnt = 1
                for i in range(height):
                    for j in range(width):
                        i_min = int(i - occluded_size / 2)
                        i_max = int(i + occluded_size / 2)
                        j_min = int(j - occluded_size / 2)
                        j_max = int(j + occluded_size / 2)
                        if i_min < 0:
                            i_min = 0
                        if i_max > height:
                            i_max = height
                        if j_min < 0:
                            j_min = 0
                        if j_max > width:
                            j_max = width
                        data[cnt, :, :, :] = image
                        data[cnt, i_min:i_max, j_min:j_max, :] = 255
                        # print(data[i].shape)
                        cnt += 1
                #
                # [idx_u]=np.where(np.max(Uy[id_y]))

                # [idx_u]=np.where(np.max(Uy))

                u_ys = np.empty([data.shape[0]], dtype='float64')
                occ_map = np.empty((width, height), dtype='float64')

                print('occ_map.shape', occ_map.shape)
                cnt = 0
                feature_y_eval = sess.run(
                    softmax3,
                    feed_dict={cycle_gan.y: [data[0]]})  #

                # print('softmax3',feature_y_eval.eval())

                u_y0 = feature_y_eval[0]
                [idx_u] = np.where(np.max(u_y0))
                idx_u = idx_u[0]
                print('feature_y_eval', feature_y_eval)
                print('u_y0', u_y0)
                max_prob = 0  # avoid shadowing the built-in max()
                print('len u_y0', len(u_y0))
                for val in range(len(u_y0)):
                    vallist = u_y0[val]
                    if vallist > max_prob:
                        max_prob = vallist

                u_y0 = max_prob
                # print('max', u_y0[idx_u])
                print('max', u_y0)
                # print('u_y01',u_y0[idx_u])

                for i in range(width):
                    for j in range(height):
                        feature_y_eval = sess.run(
                            softmax3,
                            feed_dict={cycle_gan.y: [data[cnt + 1]]})
                        u_y = feature_y_eval[0]
                        # u_y =  max(u_y)
                        print('u_y', u_y)
                        u_y1 = 0
                        for val in range(len(u_y)):
                            vallist = u_y[val]
                            if vallist > u_y1:
                                u_y1 = vallist

                        occ_value = u_y0 - u_y1
                        occ_map[i, j] = occ_value
                        print(str(cnt) + ':' + str(occ_value))
                        cnt += 1

                occ_map_path = os.path.join(occ_dir, 'occlusion_map_{}.txt'.format('1'))
                np.savetxt(occ_map_path, occ_map, fmt='%0.8f')
                cv2.imwrite(os.path.join(occ_dir, '{}.png'.format('1')), oyimage[1])
                draw_heatmap(occ_map_path=occ_map_path, ori_img=oyimage[1],
                             save_dir=os.path.join(occ_dir, 'heatmap_{}.png'.format('1')))

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:

            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
Example #14
def train_ladder(cli_params, dataset=None, save_to='results/ova_all_full'):
    cli_params['save_dir'] = prepare_dir(save_to)
    logfile = os.path.join(cli_params['save_dir'], 'log.txt')

    # Log also DEBUG to a file
    fh = logging.FileHandler(filename=logfile)
    fh.setLevel(logging.DEBUG)
    logger.addHandler(fh)

    logger.info('Logging into %s' % logfile)

    p, loaded = load_and_log_params(cli_params)

    ladder = setup_model(p)

    # Training
    all_params = ComputationGraph([ladder.costs.total]).parameters
    logger.info('Found the following parameters: %s' % str(all_params))

    # Fetch all batch normalization updates. They are in the clean path.
    bn_updates = ComputationGraph([ladder.costs.class_clean]).updates
    assert 'counter' in [u.name for u in bn_updates.keys()], \
        'No batch norm params in graph - the graph has been cut?'

    training_algorithm = GradientDescent(
        cost=ladder.costs.total,
        params=all_params,
        step_rule=Adam(learning_rate=ladder.lr))
    # In addition to actual training, also do BN variable approximations
    training_algorithm.add_updates(bn_updates)

    short_prints = {
        "train": {
            'T_C_class': ladder.costs.class_corr,
            'T_C_de': ladder.costs.denois.values(),
        },
        "valid_approx":
        OrderedDict([
            ('V_C_class', ladder.costs.class_clean),
            ('V_E', ladder.error.clean),
            ('V_C_de', ladder.costs.denois.values()),
        ]),
        "valid_final":
        OrderedDict([
            ('VF_C_class', ladder.costs.class_clean),
            ('VF_E', ladder.error.clean),
            ('VF_C_de', ladder.costs.denois.values()),
        ]),
    }

    ovadataset = dataset['ovadataset']
    train_indexes = dataset['train_indexes']
    val_indexes = dataset['val_indexes']

    main_loop = MainLoop(
        training_algorithm,
        # Datastream used for training
        make_datastream(ovadataset,
                        train_indexes,
                        p.batch_size,
                        scheme=ShuffledScheme),
        model=Model(ladder.costs.total),
        extensions=[
            FinishAfter(after_n_epochs=p.num_epochs),

            # This will estimate the validation error using
            # running average estimates of the batch normalization
            # parameters, mean and variance
            ApproxTestMonitoring(
                [ladder.costs.class_clean, ladder.error.clean] +
                ladder.costs.denois.values(),
                make_datastream(ovadataset, val_indexes, p.batch_size),
                prefix="valid_approx"),

            # This Monitor is slower, but more accurate since it will first
            # estimate batch normalization parameters from training data and
            # then do another pass to calculate the validation error.
            FinalTestMonitoring(
                [ladder.costs.class_clean, ladder.error.clean_mc] +
                ladder.costs.denois.values(),
                make_datastream(ovadataset, train_indexes, p.batch_size),
                make_datastream(ovadataset, val_indexes, p.batch_size),
                prefix="valid_final",
                after_n_epochs=p.num_epochs),
            TrainingDataMonitoring([
                ladder.costs.total, ladder.costs.class_corr,
                training_algorithm.total_gradient_norm
            ] + ladder.costs.denois.values(),
                                   prefix="train",
                                   after_epoch=True),
            ShortPrinting(short_prints),
            LRDecay(ladder.lr,
                    p.num_epochs * p.lrate_decay,
                    p.num_epochs,
                    after_epoch=True),
        ])
    main_loop.run()

    # Get results
    df = main_loop.log.to_dataframe()
    col = 'valid_final_error_matrix_cost'
    logger.info('%s %g' % (col, df[col].iloc[-1]))

    ds = make_datastream(ovadataset, val_indexes, p.batch_size)
    outputs = ladder.act.clean.labeled.h[len(ladder.layers) - 1]
    outputreplacer = TestMonitoring()
    _, _, outputs = outputreplacer._get_bn_params(outputs)

    cg = ComputationGraph(outputs)
    f = cg.get_theano_function()

    it = ds.get_epoch_iterator(as_dict=True)
    res = []
    inputs = {
        'features_labeled': [],
        'targets_labeled': [],
        'features_unlabeled': []
    }
    # Loop over one epoch
    for d in it:
        # Store all inputs
        for k, v in d.iteritems():
            inputs[k] += [v]
        # Store outputs
        res += [f(*[d[str(inp)] for inp in cg.inputs])]

    # Concatenate all minibatches
    res = [numpy.vstack(minibatches) for minibatches in zip(*res)]
    inputs = {k: numpy.vstack(v) for k, v in inputs.iteritems()}

    if main_loop.log.status['epoch_interrupt_received']:
        return None
    return res[0], inputs
Example #15
import os
import csv

import parameters as par
import utils
import pandas as pd

# Set filename of csv file containing the search results
FILENAME = 'results.csv'

VAR1 = par.hyperparameter1_search
VAR2 = par.hyperparameter2_search

# Overwrite the results file if it already exists
if os.path.isfile(f"./{FILENAME}"):
    os.remove(f"./{FILENAME}")

# Empty the folder where the prediction plots will be stored
utils.prepare_dir(par.pred_dir, empty=True)


def write_to_csv(row):
    """Write a text row to a csv file."""
    with open(FILENAME, 'a', newline='') as file:
        writer = csv.writer(file,
                            delimiter=',',
                            quotechar='|',
                            quoting=csv.QUOTE_MINIMAL)
        writer.writerow(row)


# Write the column titles to the results file.
write_to_csv([VAR1['Name'], VAR2['Name'], 'Accuracy'])
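The snippet above only prepares the results file; a hedged sketch of the loop that would fill it is shown below. run_experiment and the 'Values' entries are hypothetical stand-ins for however the project actually enumerates its two hyperparameters.

# Hypothetical search loop; run_experiment and the 'Values' keys are assumptions.
for value1 in VAR1.get('Values', []):
    for value2 in VAR2.get('Values', []):
        accuracy = run_experiment(value1, value2)   # hypothetical train/evaluate call
        write_to_csv([value1, value2, accuracy])    # one row per hyperparameter pair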
Example #16
import sys

import numpy as np
import tensorflow as tf

from utils import Accumulator
import utils
import graph_builder
import mixer

if __name__ == '__main__':
    print(' '.join(sys.argv))

    parser = utils.get_argparser()
    parser.add_argument('--n-skip',
                        default=1,
                        type=int,
                        help='Number of frames to skip')
    args = parser.parse_args()
    print(args)
    utils.prepare_dir(args)
    utils.print_host_info()

    tf.get_variable_scope()._reuse = None

    _seed = args.base_seed + args.add_seed
    tf.set_random_seed(_seed)
    np.random.seed(_seed)

    tg, test_graph = graph_builder.build_graph_subsample(args)
    tvars = tf.trainable_variables()
    print([tvar.name for tvar in tvars])
    print("Model size: {:.2f}M".format(utils.get_model_size(tvars)))

    tg_ml_cost = tf.reduce_mean(tg.ml_cost)
    global_step = tf.Variable(0, trainable=False, name="global_step")
Example #17
def main():
    args = parser.parse_args()
    env_name = args.env_name
    input_file = args.input_file
    checkpoint_file = args.resume
    test_only = args.test_only
    seed = args.seed
    no_gpu = args.no_gpu
    dir_name = args.dir_name
    visualize = args.visualize
    n_test_steps = args.n_test_steps
    log_perf_file = args.log_perf_file
    min_distance = args.min_distance
    max_distance = args.max_distance
    threshold = args.threshold
    y_range = args.y_range
    n_training_samples = args.n_training_samples
    start_index = args.start_index
    exp_name = args.exp_name
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    n_epochs = args.n_epochs

    # Specific to Humanoid - Pybullet
    if visualize and env_name == 'HumanoidBulletEnv-v0':
        spec = gym.envs.registry.env_specs[env_name]
        class_ = gym.envs.registration.load(spec._entry_point)
        env = class_(**{**spec._kwargs}, **{'render': True})
    else:
        env = gym.make(env_name)

    set_global_seed(seed)
    env.seed(seed)

    input_shape = env.observation_space.shape[0] + 3
    output_shape = env.action_space.shape[0]
    net = Policy(input_shape, output_shape)
    if not no_gpu:
        net = net.cuda()
    optimizer = Adam(net.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()
    epochs = 0

    if checkpoint_file:
        epochs, net, optimizer = load_checkpoint(checkpoint_file, net,
                                                 optimizer)

    if not checkpoint_file and test_only:
        print('ERROR: You have not entered a checkpoint file.')
        return

    if not test_only:
        if not os.path.isfile(input_file):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                                    input_file)

        training_file = open(input_file, 'rb')
        old_states = []
        norms = []
        goals = []
        actions = []
        n_samples = -1

        while n_samples - start_index < n_training_samples:
            try:
                old_s, old_g, new_s, new_g, action = pickle.load(training_file)
                n_samples += 1

                if n_samples < start_index:
                    continue

                old_states.append(np.squeeze(np.array(old_s)))
                norms.append(
                    find_norm(np.squeeze(np.array(new_g) - np.array(old_g))))
                goals.append(
                    preprocess_goal(
                        np.squeeze(np.array(new_g) - np.array(old_g))))
                actions.append(np.squeeze(np.array(action)))
            except (EOFError, ValueError):
                break

        old_states = np.array(old_states)
        norms = np.array(norms)
        goals = np.array(goals)
        actions = np.array(actions)

        normalization_factors = {
            'state': [old_states.mean(axis=0),
                      old_states.std(axis=0)],
            'distance_per_step': [norms.mean(axis=0),
                                  norms.std(axis=0)]
        }
        n_file = open(env_name + '_normalization_factors.pkl', 'wb')
        pickle.dump(normalization_factors, n_file)
        n_file.close()

        old_states = normalize(old_states,
                               env_name + '_normalization_factors.pkl',
                               'state')

        # Summary writer for tensorboardX
        writer = {}
        writer['writer'] = SummaryWriter()

        # Split data into training and validation
        indices = np.arange(old_states.shape[0])
        shuffle(indices)
        val_data = np.concatenate(
            (old_states[indices[:int(old_states.shape[0] / 5)]],
             goals[indices[:int(old_states.shape[0] / 5)]]),
            axis=1)
        val_labels = actions[indices[:int(old_states.shape[0] / 5)]]
        training_data = np.concatenate(
            (old_states[indices[int(old_states.shape[0] / 5):]],
             goals[indices[int(old_states.shape[0] / 5):]]),
            axis=1)
        training_labels = actions[indices[int(old_states.shape[0] / 5):]]
        del old_states, norms, goals, actions, indices

        checkpoint_dir = os.path.join(env_name, 'naive_gcp_checkpoints')
        if dir_name:
            checkpoint_dir = os.path.join(checkpoint_dir, dir_name)
        prepare_dir(checkpoint_dir)

        for e in range(epochs, n_epochs):
            ep_loss = []
            # Train network
            for i in range(int(len(training_data) / batch_size) + 1):
                inp = training_data[batch_size * i:batch_size * (i + 1)]
                out = net(
                    convert_to_variable(inp, grad=False, gpu=(not no_gpu)))
                target = training_labels[batch_size * i:batch_size * (i + 1)]
                target = convert_to_variable(np.array(target),
                                             grad=False,
                                             gpu=(not no_gpu))
                loss = criterion(out, target)
                optimizer.zero_grad()
                ep_loss.append(loss.item())
                loss.backward()
                optimizer.step()

            # Validation
            val_loss = []
            for i in range(int(len(val_data) / batch_size) + 1):
                inp = val_data[batch_size * i:batch_size * (i + 1)]
                out = net(
                    convert_to_variable(inp, grad=False, gpu=(not no_gpu)))
                target = val_labels[batch_size * i:batch_size * (i + 1)]
                target = convert_to_variable(np.array(target),
                                             grad=False,
                                             gpu=(not no_gpu))
                loss = criterion(out, target)
                val_loss.append(loss.item())

            writer['iter'] = e + 1
            writer['writer'].add_scalar('data/val_loss',
                                        np.array(val_loss).mean(), e + 1)
            writer['writer'].add_scalar('data/training_loss',
                                        np.array(ep_loss).mean(), e + 1)

            save_checkpoint(
                {
                    'epochs': (e + 1),
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict()
                },
                filename=os.path.join(checkpoint_dir,
                                      str(e + 1) + '.pth.tar'))

            print('Epoch:', e + 1)
            print('Training loss:', np.array(ep_loss).mean())
            print('Val loss:', np.array(val_loss).mean())
            print('')

    # Now we use the trained net to see how the agent reaches a different
    # waypoint from the current one.

    success = 0
    failure = 0

    closest_distances = []
    time_to_closest_distances = []

    f = open(env_name + '_normalization_factors.pkl', 'rb')
    normalization_factors = pickle.load(f)
    average_distance = normalization_factors['distance_per_step'][0]

    for i in range(n_test_steps):
        state = env.reset()
        if env_name == 'Ant-v2':
            obs = env.unwrapped.get_body_com('torso')
            target_obs = [
                obs[0] + np.random.uniform(min_distance, max_distance),
                obs[1] + np.random.uniform(-y_range, y_range), obs[2]
            ]
            target_obs = rotate_point(target_obs, env.unwrapped.angle)
            env.unwrapped.sim.model.body_pos[-1] = target_obs
        elif env_name == 'MinitaurBulletEnv-v0':
            obs = env.unwrapped.get_minitaur_position()
            target_obs = [
                obs[0] + np.random.uniform(min_distance, max_distance),
                obs[1] + np.random.uniform(-y_range, y_range), obs[2]
            ]
            target_obs = rotate_point(
                target_obs, env.unwrapped.get_minitaur_rotation_angle())
            env.unwrapped.set_target_position(target_obs)
        elif env_name == 'HumanoidBulletEnv-v0':
            obs = env.unwrapped.robot.get_robot_position()
            target_obs = [
                obs[0] + np.random.uniform(min_distance, max_distance),
                obs[1] + np.random.uniform(-y_range, y_range), obs[2]
            ]
            target_obs = rotate_point(target_obs, env.unwrapped.robot.yaw)
            env.unwrapped.robot.set_target_position(target_obs[0],
                                                    target_obs[1])
        steps = 0
        done = False
        closest_d = distance(obs, target_obs)
        closest_t = 0
        while distance(obs, target_obs) > threshold and not done:
            goal = preprocess_goal(target_obs - obs)
            state = normalize(np.array(state),
                              env_name + '_normalization_factors.pkl')
            inp = np.concatenate([np.squeeze(state), goal])
            inp = convert_to_variable(inp, grad=False, gpu=(not no_gpu))
            action = net(inp).cpu().detach().numpy()
            state, _, done, _ = env.step(action)
            steps += 1
            if env_name == 'MinitaurBulletEnv-v0':
                obs = env.unwrapped.get_minitaur_position()
            elif env_name == 'HumanoidBulletEnv-v0':
                obs = env.unwrapped.robot.get_robot_position()
            if distance(obs, target_obs) < closest_d:
                closest_d = distance(obs, target_obs)
                closest_t = steps
            if visualize:
                env.render()

        if distance(obs, target_obs) <= threshold:
            success += 1
        elif done:
            failure += 1

        if visualize:
            time.sleep(2)

        closest_distances.append(closest_d)
        time_to_closest_distances.append(closest_t)

    print('Successes: %d, Failures: %d, '
          'Closest distance: %f, Time to closest distance: %d' %
          (success, failure, np.mean(closest_distances),
           np.mean(time_to_closest_distances)))

    if log_perf_file:
        f = open(log_perf_file, 'a+')
        f.write(exp_name + ':Seed-' + str(seed) + ',Success-' + str(success) +
                ',Failure-' + str(failure) + ',Closest_distance-' +
                str(closest_distances) + ',Time_to_closest_distance-' +
                str(time_to_closest_distances) + '\n')
        f.close()
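Example #17 turns the displacement between consecutive waypoints into training targets via find_norm and preprocess_goal. A hedged sketch of what such helpers could compute; the project's actual definitions may differ (preprocess_goal might, for instance, also rescale by the average step length).

# Hypothetical sketches of the goal helpers used above; real implementations may differ.
import numpy as np


def find_norm(v):
    # Euclidean length of a displacement vector (distance covered in one step).
    return np.linalg.norm(np.asarray(v, dtype=np.float64))


def preprocess_goal(v):
    # Unit vector pointing from the current position towards the goal.
    v = np.asarray(v, dtype=np.float32)
    return v / (np.linalg.norm(v) + 1e-8)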
Example #18
def train(cli_params):
    fn = 'noname'
    if 'save_to' in nodefaultargs or not cli_params.get('load_from'):
        fn = cli_params['save_to']
    cli_params['save_dir'] = prepare_dir(fn)
    nodefaultargs.append('save_dir')

    logfile = os.path.join(cli_params['save_dir'], 'log.txt')

    # Log also DEBUG to a file
    fh = logging.FileHandler(filename=logfile)
    fh.setLevel(logging.DEBUG)
    logger.addHandler(fh)

    logger.info('Logging into %s' % logfile)

    p, loaded = load_and_log_params(cli_params)

    in_dim, data, whiten, cnorm = setup_data(p, test_set=False)
    if not loaded:
        # Set the zero layer to match input dimensions
        p.encoder_layers = (in_dim, ) + p.encoder_layers

    ladder = setup_model(p)

    # Training
    all_params = ComputationGraph([ladder.costs.total]).parameters
    logger.info('Found the following parameters: %s' % str(all_params))

    # Fetch all batch normalization updates. They are in the clean path.
    # you can turn off BN by setting is_normalizing = False in ladder.py
    bn_updates = ComputationGraph([ladder.costs.class_clean]).updates
    assert not bn_updates or 'counter' in [u.name for u in bn_updates.keys()], \
        'No batch norm params in graph - the graph has been cut?'

    training_algorithm = GradientDescent(
        cost=ladder.costs.total,
        parameters=all_params,
        step_rule=Adam(learning_rate=ladder.lr))
    # In addition to actual training, also do BN variable approximations
    if bn_updates:
        training_algorithm.add_updates(bn_updates)

    short_prints = {
        "train":
        OrderedDict([
            ('T_E', ladder.error.clean),
            ('T_O', ladder.oos.clean),
            ('T_C_class', ladder.costs.class_corr),
            ('T_C_de', ladder.costs.denois.values()),
            ('T_T', ladder.costs.total),
        ]),
        "valid_approx":
        OrderedDict([
            ('V_C_class', ladder.costs.class_clean),
            ('V_E', ladder.error.clean),
            ('V_O', ladder.oos.clean),
            ('V_C_de', ladder.costs.denois.values()),
            ('V_T', ladder.costs.total),
        ]),
        "valid_final":
        OrderedDict([
            ('VF_C_class', ladder.costs.class_clean),
            ('VF_E', ladder.error.clean),
            ('VF_O', ladder.oos.clean),
            ('VF_C_de', ladder.costs.denois.values()),
            ('V_T', ladder.costs.total),
        ]),
    }

    if len(data.valid_ind):
        main_loop = MainLoop(
            training_algorithm,
            # Datastream used for training
            make_datastream(data.train,
                            data.train_ind,
                            p.batch_size,
                            n_labeled=p.labeled_samples,
                            n_unlabeled=p.unlabeled_samples,
                            whiten=whiten,
                            cnorm=cnorm,
                            balanced_classes=p.balanced_classes,
                            dseed=p.dseed),
            model=Model(ladder.costs.total),
            extensions=[
                FinishAfter(after_n_epochs=p.num_epochs),

                # This will estimate the validation error using
                # running average estimates of the batch normalization
                # parameters, mean and variance
                ApproxTestMonitoring([
                    ladder.costs.class_clean, ladder.error.clean,
                    ladder.oos.clean, ladder.costs.total
                ] + ladder.costs.denois.values(),
                                     make_datastream(
                                         data.valid,
                                         data.valid_ind,
                                         p.valid_batch_size,
                                         whiten=whiten,
                                         cnorm=cnorm,
                                         balanced_classes=p.balanced_classes,
                                         scheme=ShuffledScheme),
                                     prefix="valid_approx"),

                # This Monitor is slower, but more accurate since it will first
                # estimate batch normalization parameters from training data and
                # then do another pass to calculate the validation error.
                FinalTestMonitoring(
                    [
                        ladder.costs.class_clean, ladder.error.clean,
                        ladder.oos.clean, ladder.costs.total
                    ] + ladder.costs.denois.values(),
                    make_datastream(data.train,
                                    data.train_ind,
                                    p.batch_size,
                                    n_labeled=p.labeled_samples,
                                    whiten=whiten,
                                    cnorm=cnorm,
                                    balanced_classes=p.balanced_classes,
                                    scheme=ShuffledScheme),
                    make_datastream(data.valid,
                                    data.valid_ind,
                                    p.valid_batch_size,
                                    n_labeled=len(data.valid_ind),
                                    whiten=whiten,
                                    cnorm=cnorm,
                                    balanced_classes=p.balanced_classes,
                                    scheme=ShuffledScheme),
                    prefix="valid_final",
                    after_n_epochs=p.num_epochs,
                    after_training=True),
                TrainingDataMonitoring([
                    ladder.error.clean, ladder.oos.clean, ladder.costs.total,
                    ladder.costs.class_corr,
                    training_algorithm.total_gradient_norm
                ] + ladder.costs.denois.values(),
                                       prefix="train",
                                       after_epoch=True),
                # ladder.costs.class_clean - save model whenever we have best validation result another option `('train',ladder.costs.total)`
                SaveParams(('valid_approx', ladder.error.clean),
                           all_params,
                           p.save_dir,
                           after_epoch=True),
                SaveExpParams(p, p.save_dir, before_training=True),
                SaveLog(p.save_dir, after_training=True),
                ShortPrinting(short_prints),
                LRDecay(ladder.lr,
                        p.num_epochs * p.lrate_decay,
                        p.num_epochs,
                        lrmin=p.lrmin,
                        after_epoch=True),
            ])
    else:
        main_loop = MainLoop(
            training_algorithm,
            # Datastream used for training
            make_datastream(data.train,
                            data.train_ind,
                            p.batch_size,
                            n_labeled=p.labeled_samples,
                            n_unlabeled=p.unlabeled_samples,
                            whiten=whiten,
                            cnorm=cnorm,
                            balanced_classes=p.balanced_classes,
                            dseed=p.dseed),
            model=Model(ladder.costs.total),
            extensions=[
                FinishAfter(after_n_epochs=p.num_epochs),
                TrainingDataMonitoring([
                    ladder.error.clean, ladder.oos.clean, ladder.costs.total,
                    ladder.costs.class_corr,
                    training_algorithm.total_gradient_norm
                ] + ladder.costs.denois.values(),
                                       prefix="train",
                                       after_epoch=True),
                # ladder.costs.class_clean - save model whenever we have best validation result another option `('train',ladder.costs.total)`
                SaveParams(('train', ladder.error.clean),
                           all_params,
                           p.save_dir,
                           after_epoch=True),
                SaveExpParams(p, p.save_dir, before_training=True),
                SaveLog(p.save_dir, after_training=True),
                ShortPrinting(short_prints),
                LRDecay(ladder.lr,
                        p.num_epochs * p.lrate_decay,
                        p.num_epochs,
                        lrmin=p.lrmin,
                        after_epoch=True),
            ])
    main_loop.run()

    # Get results
    if len(data.valid_ind) == 0:
        return None

    df = DataFrame.from_dict(main_loop.log, orient='index')
    col = 'valid_final_error_rate_clean'
    logger.info('%s %g' % (col, df[col].iloc[-1]))

    if main_loop.log.status['epoch_interrupt_received']:
        return None
    return df
Пример #19
0
def ansible_restore(cmds):

    if not (bool(cmds.path) ^ bool(cmds.s3)):
        raise Exception('Exactly one of --path or --s3 must be specified')

    if not cmds.nodes:
        config = ConfigParser()
        if len(config.read('config.ini')) == 0:
            raise Exception(
                'ERROR: Cannot find config.ini in script directory')
        nodes = re.findall(r'[^,\s\[\]]+', config.get('cassandra-info',
                                                       'hosts'))
        if not nodes:
            raise Exception('Hosts argument in config.ini not specified')
    else:
        nodes = cmds.nodes

    # prepare working directories
    temp_path = sys.path[0] + '/.temp'
    prepare_dir(sys.path[0] + '/output_logs', output=True)
    prepare_dir(temp_path, output=True)

    if cmds.path:
        zip_path = cmds.path
    elif cmds.s3:
        s3 = s3_bucket()
        s3_snapshots = s3_list_snapshots(s3)

        if cmds.s3 == True:  # not a string parameter
            if len(s3_snapshots) == 0:
                print('No snapshots found in s3')
                exit(0)

            # search
            print('\nSnapshots found:')
            template = '{0:5} | {1:67}'
            print(template.format('Index', 'Snapshot'))
            for idx, snap in enumerate(s3_snapshots):
                # every snapshot starts with cassandra-snapshot- (19 chars)
                stripped = snap[19:]
                print(template.format(idx + 1, stripped))

            index = 0
            while index not in range(1, len(s3_snapshots) + 1):
                try:
                    index = int(raw_input('Enter snapshot index: '))
                except ValueError:
                    continue
            s3_key = s3_snapshots[index - 1]

        else:
            s3_key = cmds.s3
            if not s3_key.startswith('cassandra-snapshot-'):
                s3_key = 'cassandra-snapshot-' + s3_key

            if s3_key not in s3_snapshots:
                raise Exception('S3 Snapshot not found')

        print('Retrieving snapshot from S3: %s' % s3_key)
        s3.download_file(s3_key, temp_path + '/temp.zip')
        zip_path = temp_path + '/temp.zip'
    else:
        raise Exception('No file specified.')

    # unzip
    print('Unzipping snapshot file')
    z = zipfile.ZipFile(zip_path, 'r', allowZip64=True)
    z.extractall(temp_path)

    # check schema specification args
    print('Checking arguments . . .')
    restore_command = 'restore.py '
    load_schema_command = 'load_schema.py '
    if cmds.keyspace:

        schema = get_zipped_schema(temp_path + '/schemas.zip')
        for keyspace in cmds.keyspace:
            if keyspace not in schema.keys():
                raise Exception('ERROR: Keyspace "%s" not in snapshot schema' %
                                keyspace)

        keyspace_arg = '-ks ' + ' '.join(cmds.keyspace)
        restore_command += keyspace_arg
        load_schema_command += keyspace_arg

        if cmds.table:

            if len(cmds.keyspace) != 1:
                raise Exception(
                    'ERROR: One keyspace must be specified with table argument'
                )

            ks = cmds.keyspace[0]
            for tb in cmds.table:
                if tb not in schema[ks]:
                    raise Exception(
                        'ERROR: Table "%s" not found in keyspace "%s"' %
                        (tb, ks))

            restore_command += ' -tb ' + ' '.join(cmds.table)

    elif cmds.table:
        raise Exception('ERROR: Keyspace must be specified with tables')

    playbook_args = {
        'nodes': ' '.join(nodes),
        'restore_command': restore_command,
        'load_schema_command': load_schema_command,
        'reload': cmds.reload,
        'hard_reset': cmds.hard_reset
    }
    return_code = run_playbook('restore.yml', playbook_args)

    if return_code != 0:
        print('ERROR: Ansible script failed to run properly. ' +
              'If this persists, try --hard-reset. (TODO)')  # TODO
    else:
        print('Process complete.')
        print('Output logs saved in %s' % (sys.path[0] + '/output_logs'))
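The cmds namespace consumed above is presumably produced by an argparse CLI. Below is a minimal sketch of how such a parser might be wired up; the flag names, defaults, and the bare-flag handling for --s3 are assumptions, not taken from the original script:

import argparse

def build_restore_parser():
    # hypothetical CLI wiring for ansible_restore(); the attribute names match
    # the cmds.* fields used above, everything else is assumed
    parser = argparse.ArgumentParser(description='Restore a Cassandra snapshot')
    parser.add_argument('--path', help='path to a local snapshot zip')
    parser.add_argument('--s3', nargs='?', const=True, default=False,
                        help='restore from S3; optionally pass a snapshot key')
    parser.add_argument('-n', '--nodes', nargs='+', help='target hosts')
    parser.add_argument('-ks', '--keyspace', nargs='+', help='keyspaces to restore')
    parser.add_argument('-tb', '--table', nargs='+', help='tables to restore')
    parser.add_argument('--reload', action='store_true')
    parser.add_argument('--hard-reset', dest='hard_reset', action='store_true')
    return parser

# usage sketch: ansible_restore(build_restore_parser().parse_args())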
Пример #20
0
def train(cli_params):
    cli_params['save_dir'] = prepare_dir(cli_params['save_to'])
    logfile = os.path.join(cli_params['save_dir'], 'log.txt')

    # Log also DEBUG to a file
    fh = logging.FileHandler(filename=logfile)
    fh.setLevel(logging.DEBUG)
    logger.addHandler(fh)

    logger.info('Logging into %s' % logfile)

    p, loaded = load_and_log_params(cli_params)
    in_dim, data, whiten, cnorm = setup_data(p, test_set=False)
    if not loaded:
        # Set the zero layer to match input dimensions
        p.encoder_layers = (in_dim,) + p.encoder_layers

    ladder = setup_model(p)

    # Training
    all_params = ComputationGraph([ladder.costs.total]).parameters
    logger.info('Found the following parameters: %s' % str(all_params))

    # Fetch all batch normalization updates. They are in the clean path.
    bn_updates = ComputationGraph([ladder.costs.class_clean]).updates
    assert 'counter' in [u.name for u in bn_updates.keys()], \
        'No batch norm params in graph - the graph has been cut?'

    training_algorithm = GradientDescent(
        cost=ladder.costs.total, parameters=all_params,
        step_rule=Adam(learning_rate=ladder.lr.get_value()))
    # In addition to actual training, also do BN variable approximations
    training_algorithm.add_updates(bn_updates)

    short_prints = {
        "train": {
            'T_C_class': ladder.costs.class_corr,
            'T_C_de': ladder.costs.denois.values(),
        },
        "valid_approx": OrderedDict([
            ('V_C_class', ladder.costs.class_clean),
            ('V_E', ladder.error.clean),
            ('V_C_de', ladder.costs.denois.values()),
        ]),
        "valid_final": OrderedDict([
            ('VF_C_class', ladder.costs.class_clean),
            ('VF_E', ladder.error.clean),
            ('VF_C_de', ladder.costs.denois.values()),
        ]),
    }

    main_loop = MainLoop(
        training_algorithm,
        # Datastream used for training
        make_datastream(data.train, data.train_ind,
                        p.batch_size,
                        n_labeled=p.labeled_samples,
                        n_unlabeled=p.unlabeled_samples,
                        whiten=whiten,
                        cnorm=cnorm),
        model=Model(ladder.costs.total),
        extensions=[
            FinishAfter(after_n_epochs=p.num_epochs),

            # This will estimate the validation error using
            # running average estimates of the batch normalization
            # parameters, mean and variance
            ApproxTestMonitoring(
                [ladder.costs.class_clean, ladder.error.clean]
                + ladder.costs.denois.values(),
                make_datastream(data.valid, data.valid_ind,
                                p.valid_batch_size, whiten=whiten, cnorm=cnorm,
                                scheme=ShuffledScheme),
                prefix="valid_approx"),

            # This Monitor is slower, but more accurate since it will first
            # estimate batch normalization parameters from training data and
            # then do another pass to calculate the validation error.
            FinalTestMonitoring(
                [ladder.costs.class_clean, ladder.error.clean]
                + ladder.costs.denois.values(),
                make_datastream(data.train, data.train_ind,
                                p.batch_size,
                                n_labeled=p.labeled_samples,
                                whiten=whiten, cnorm=cnorm,
                                scheme=ShuffledScheme),
                make_datastream(data.valid, data.valid_ind,
                                p.valid_batch_size,
                                n_labeled=len(data.valid_ind),
                                whiten=whiten, cnorm=cnorm,
                                scheme=ShuffledScheme),
                prefix="valid_final",
                after_n_epochs=p.num_epochs),

            TrainingDataMonitoring(
                [ladder.costs.total, ladder.costs.class_corr,
                 training_algorithm.total_gradient_norm]
                + ladder.costs.denois.values(),
                prefix="train", after_epoch=True),

            SaveParams(None, all_params, p.save_dir, after_epoch=True),
            SaveExpParams(p, p.save_dir, before_training=True),
            SaveLog(p.save_dir, after_training=True),
            ShortPrinting(short_prints),
            LRDecay(ladder.lr, p.num_epochs * p.lrate_decay, p.num_epochs,
                    after_epoch=True),
        ])
    main_loop.run()

    # Get results
    df = DataFrame.from_dict(main_loop.log, orient='index')
    col = 'valid_final_error_rate_clean'
    logger.info('%s %g' % (col, df[col].iloc[-1]))

    if main_loop.log.status['epoch_interrupt_received']:
        return None
    return df
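train() relies on a prepare_dir() helper that turns the save_to name into an existing directory and returns its path. The helper itself is not part of this snippet; a minimal sketch consistent with how it is used here (the base directory name is an assumption) could be:

import os

def prepare_dir(save_to, base='results'):
    # hypothetical stand-in for the project's prepare_dir(): create the
    # directory if it is missing and hand back its path
    save_dir = os.path.join(base, save_to)
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    return save_dir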
Пример #21
0
        plt.title('original picture')
        orimg = (orimg - orimg.min()) / (orimg.max() - orimg.min()) * 255
        orimg = orimg.astype(np.uint8)
        plt.imshow(orimg)
        for idx, layer in enumerate([0, 1, 2, 3, 4]):
            # for idx, layer in enumerate(vgg16_conv.conv_layer_indices):
            plt.subplot(2, 3, idx + 2)
            img, activation = vis_layer(layer, net, denet)
            plt.title(
                f'reconstruction from Conv_{layer}, max activation = {activation}'
            )
            plt.imshow(img)

        # plt.show()
        utils.prepare_dir('./out_visualization', empty=False)
        plt.savefig('./out_visualization/reconstruction.jpg')
        # plt.savefig(os.path.join(PLOT_DIR, 'reconstruction.jpg'), bbox_inches='tight')
        # Visualize the first-layer filters
        viewlayer = 0
        parm = {}  # filterviewer
        parmidx = {0: 0, 1: 16, 2: 34, 3: 52, 4: 70}
        for name, parameters in net.named_parameters():
            parm[name] = parameters.detach().cpu().numpy()
        parmdic = list(parm.keys())
        weight = parm[parmdic[parmidx[viewlayer]]]
        view_lib.plot_conv_weights(weight, 'conv{}'.format(viewlayer))

        # Visualize the feature maps
        viewlayer = 0
Пример #22
0
def ansible_snapshot(cmds):

    # set title of snapshot file
    timestamp = str(time.time()).split('.')[0]
    if cmds.title:
        title = cmds.title
    else:
        title = timestamp

    if not cmds.nodes:
        config = ConfigParser()
        if len(config.read('config.ini')) == 0:
            raise Exception(
                'ERROR: Cannot find config.ini in script directory')
        nodes = re.findall(r'[^,\s\[\]]+', config.get('cassandra-info',
                                                       'hosts'))
        if not nodes:
            raise Exception('Hosts argument in config.ini not specified')
    else:
        nodes = cmds.nodes

    if cmds.s3:
        s3 = s3_bucket()  # checks config.ini args

    # path to save snapshot in
    if cmds.path:
        save_path = cmds.path
    else:
        save_path = sys.path[0] + '/snapshots'
        make_dir(save_path)

    if os.path.isfile(save_path + '/' + title + '.zip'):
        raise Exception('%s has already been created' %
                        (save_path + '/' + title + '.zip'))

    # prepare working directories
    temp_path = sys.path[0] + '/.temp'
    prepare_dir(sys.path[0] + '/output_logs')
    prepare_dir(temp_path)
    os.makedirs(temp_path + '/' + title)

    # check keyspace and table args
    snapshotter_command = 'snapshotter.py '
    save_schema_command = 'save_schema.py '
    if cmds.keyspace:

        keyspace_arg = '-ks ' + ' '.join(cmds.keyspace)
        snapshotter_command += keyspace_arg
        save_schema_command += keyspace_arg

        if cmds.table:
            if len(cmds.keyspace) != 1:
                raise Exception(
                    'ERROR: One keyspace must be specified with table argument'
                )
            snapshotter_command += ' -tb ' + ' '.join(cmds.table)

    elif cmds.table:
        raise Exception('ERROR: Keyspace must be specified with tables')

    playbook_args = {
        'nodes': ' '.join(nodes),
        'snapshotter_command': snapshotter_command,
        'save_schema_command': save_schema_command,
        'path': temp_path + '/' + title,
        'reload': cmds.reload
    }

    # call playbook
    return_code = run_playbook('snapshot.yml', playbook_args)

    if return_code != 0:
        shutil.rmtree(temp_path + '/' + title)
        print('Error running ansible script')
    else:
        zip_dir(temp_path + '/' + title, save_path, title)

        if cmds.s3:

            file_size = os.path.getsize(save_path + '/' + title + '.zip')
            if confirm('Snapshot size is %i bytes. Upload? [y/n] ' %
                       file_size):
                print('Uploading to s3 . . .')
                key = 'cassandra-snapshot-' + title
                upload = True
                if key in [obj.key for obj in s3.objects.all()]:
                    upload = confirm(('"%s" already exists in the S3 bucket. ' %
                                      key) + 'Overwrite? [y/n]')
                if upload:
                    s3.upload_file(save_path + '/' + title + '.zip', key)
                    print('Uploaded with key "%s"' % key)
                else:
                    print('Skipping upload to s3 . . .')

        print('Process complete.')
        print('Output logs saved in %s' % (sys.path[0] + '/output_logs'))
        print('Snapshot zip saved in %s' % save_path)
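ansible_snapshot() relies on a confirm() prompt helper that is not shown in the snippet. A small sketch of what it presumably does, i.e. keep asking until the user answers yes or no (the real implementation may differ):

def confirm(prompt):
    # hypothetical yes/no prompt helper (use input() instead of raw_input()
    # on Python 3)
    while True:
        answer = raw_input(prompt).strip().lower()
        if answer in ('y', 'yes'):
            return True
        if answer in ('n', 'no'):
            return False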
Пример #23
0
def train():
    if FLAGS.load_model is not None:
        checkpoints_dir = "checkpoints/20190611-1650"
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir)
        except os.error:
            pass

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            #X_train_file=FLAGS.X,
            #Y_train_file=FLAGS.Y,
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf)
        G_loss, D_Y_loss, F_loss, D_X_loss, teacher_loss, student_loss, learning_loss, \
        x_correct, y_correct, fake_x_correct, softmax3, fake_x_pre, f_fakeX, fake_x, fake_y_ = cycle_gan.model()
        # G_loss, D_Y_loss, F_loss, D_X_loss, teacher_loss, student_loss, learning_loss,
        # x_correct,y_correct,fake_y_correct,softmax,fake_y_pre,f_fakeX, fake_x, fake_y= cycle_gan.model()
        #optimizers = cycle_gan.optimize(student_loss)
        summary_op = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            #checkpoints_dir = r"C:\Users\Administrator\Desktop\pure_VGG\checkpoints\2018\model.ckpt-31800"
            meta_graph_path = "checkpoints/20190611-1650/model.ckpt-94000.meta"
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, "checkpoints/20190611-1650/model.ckpt-94000")
            step = 0
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        f_yroc = np.array(
            [[1.91536183e-05, 9.99966621e-01, 1.17612826e-05, 2.36588585e-06]])
        sum_label = np.array([[0, 1, 0, 0]])
        sum_y_pre3 = np.array([[0, 1, 0, 0]])
        num_classes = FLAGS.classes

        try:
            while not coord.should_stop():

                result_dir = './result'
                plot_dir = os.path.join(result_dir, 'tsne_pca')
                roc_dir = os.path.join(result_dir, 'auc_roc')
                utils.prepare_dir(plot_dir)
                utils.prepare_dir(roc_dir)

                # x_image, x_label, _ = get_test_batch1('X', 100, FLAGS.image_size, FLAGS.image_size,
                #                                                   "./dataset/")
                # y_image, y_label, _ = get_test_batch1('Y', 10, FLAGS.image_size, FLAGS.image_size,
                #                                       "./dataset/")
                y_image, y_label = get_roc_batch(FLAGS.image_size,
                                                 FLAGS.image_size,
                                                 "./dataset/Y")
                fake_x_pre4 = []
                sotfmaxour = []
                fake_x_correct_cout = 0
                length3 = len(y_label)
                print('length3', length3)
                features_d = []
                for i in range(length3):

                    # ximgs = []
                    # xlbs = []
                    # ximgs.append(x_image[i])
                    # xlbs.append(x_label[i])

                    yimgs = []
                    ylbs = []
                    yimgs.append(y_image[i])
                    ylbs.append(y_label[i])

                    softmax_fakex, fakex_pre, ffakeX, fake_x_correct_eval = (
                        sess.run(
                            [softmax3, fake_x_pre, f_fakeX, fake_x_correct],
                            feed_dict={
                                cycle_gan.y: yimgs,
                                cycle_gan.y_label: ylbs,
                                # cycle_gan.x: ximgs,
                                # cycle_gan.x_label: xlbs
                            }))
                    step += 1
                    features_d.append(ffakeX[0])
                    if fake_x_correct_eval:
                        fake_x_correct_cout = fake_x_correct_cout + 1

                    print('fake_x_correct_eval', fake_x_correct_eval)
                    print('-----------Step %d:-------------' % step)
                    # print('ylbs', ylbs)
                    # print('softmax_fakex',softmax_fakex[0])
                    sotfmaxour.append(softmax_fakex[0])
                    fakex_pre_zhi = fakex_pre[0]
                    fake_x_pre4.append(fakex_pre_zhi)

                # print('y-label',y_label)
                # print('fake_x_pre4', fake_x_pre4)
                # print('sotfmaxour',sotfmaxour)

                print('fake_x_accuracy: {}'.format(fake_x_correct_cout /
                                                   length3))
                one_hot = dense_to_one_hot(np.array(fake_x_pre4),
                                           num_classes=num_classes)
                sum_label = dense_to_one_hot(np.array(y_label),
                                             num_classes=num_classes)
                # print('one_hot', one_hot)
                # accuracy, F1-score, etc.
                jisuan(y_label, fake_x_pre4)
                print('sum_label.shape', sum_label.shape)
                print(np.array(sotfmaxour).shape)

                roc(sum_label, sotfmaxour, num_classes, roc_dir)

                # TSNE

                print('len features_d', len(features_d))
                tsne = TSNE(n_components=2,
                            learning_rate=4).fit_transform(features_d)
                pca = PCA().fit_transform(features_d)
                plt.figure(figsize=(12, 6))
                plt.subplot(121)
                plt.scatter(tsne[:, 0], tsne[:, 1], c=y_label)
                plt.subplot(122)
                plt.scatter(pca[:, 0], pca[:, 1], c=y_label)
                plt.colorbar()  # the colorbar shows which color corresponds to which class

                plt.savefig(os.path.join(plot_dir, 'plot.pdf'))
                exit()

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:
            # save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
            # print("Model saved in file: %s" % save_path)
            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
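The evaluation above converts the predicted class indices to one-hot vectors with dense_to_one_hot() before computing ROC curves. That helper is not included in the snippet; a minimal NumPy sketch with the same call signature (an assumption about the actual helper) could be:

import numpy as np

def dense_to_one_hot(labels, num_classes):
    # hypothetical helper: map an array of class indices to one-hot rows
    labels = np.asarray(labels).ravel().astype(int)
    one_hot = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    one_hot[np.arange(labels.shape[0]), labels] = 1.0
    return one_hot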
Пример #24
0
 train_batch = torch.utils.data.DataLoader(train_data,
                                           batch_size=BATCH_SIZE,
                                           shuffle=True,
                                           num_workers=0)
 test_batch = torch.utils.data.DataLoader(test_data,
                                          batch_size=BATCH_SIZE,
                                          shuffle=False,
                                          num_workers=0)
 criterion = nn.CrossEntropyLoss()  # cross-entropy loss
 optimizer = torch.optim.SGD(net.parameters(),
                             lr=lr,
                             momentum=0.9,
                             weight_decay=5e-4)  # weight decay
 model = net.to(device)
 best_acc = 60
 utils.prepare_dir('./train_data', empty=False)
 with open("./train_data/acc.txt", "w") as f:
     with open("./train_data/log.txt", "w") as f2:
         for epoch in range(1, EPOCH):
             if epoch > 60:
                 optimizer = torch.optim.SGD(net.parameters(),
                                             lr=0.001,
                                             momentum=0.5,
                                              weight_decay=5e-4)  # weight decay
             starttime = time()
             print('\nEpoch: %d start:' % epoch)
             net.train()
             sum_loss = 0.0
             correct = 0.0
             num = 0.0
             for step, data in enumerate(train_batch, 0):
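The example is cut off inside the epoch loop. Purely as an illustration (not the author's original code), a self-contained sketch of what the truncated inner loop presumably does, assuming standard cross-entropy training with the criterion and optimizer defined above:

def train_one_epoch(model, loader, criterion, optimizer, device):
    # hypothetical sketch of the loop body the snippet truncates: forward pass,
    # backprop, and running loss/accuracy counters
    model.train()
    sum_loss, correct, num = 0.0, 0.0, 0.0
    for step, (inputs, labels) in enumerate(loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        sum_loss += loss.item()
        _, predicted = outputs.max(1)
        num += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    return sum_loss / (step + 1), 100.0 * correct / num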
Пример #25
0
def ansible_snapshot(cmds):

    # set title of snapshot file
    timestamp = str(time.time()).split('.')[0]
    if cmds.title:
        title = cmds.title
    else:
        title = timestamp

    if not cmds.nodes:
        config = ConfigParser()
        if len(config.read('config.ini')) == 0:
            raise Exception('ERROR: Cannot find config.ini in script directory')
        nodes = re.findall(r'[^,\s\[\]]+', config.get('cassandra-info', 'hosts'))
        if not nodes:
            raise Exception('Hosts argument in config.ini not specified')
    else:
        nodes = cmds.nodes

    if cmds.s3:
        s3 = s3_bucket() # checks config.ini args

    # path to save snapshot in
    if cmds.path:
        save_path = cmds.path
    else:
        save_path = sys.path[0] + '/snapshots'
        make_dir(save_path)
    
    if os.path.isfile(save_path + '/' + title + '.zip'):
        raise Exception('%s has already been created' %
                        (save_path + '/' + title + '.zip'))

    # prepare working directories
    temp_path = sys.path[0] + '/.temp'
    prepare_dir(sys.path[0] + '/output_logs')
    prepare_dir(temp_path)
    os.makedirs(temp_path + '/' + title)

    # check keyspace and table args
    snapshotter_command = 'snapshotter.py '
    save_schema_command = 'save_schema.py '
    if cmds.keyspace:

        keyspace_arg = '-ks ' + ' '.join(cmds.keyspace)
        snapshotter_command += keyspace_arg
        save_schema_command += keyspace_arg

        if cmds.table:
            if len(cmds.keyspace) != 1:
                raise Exception('ERROR: One keyspace must be specified with table argument')
            snapshotter_command += ' -tb ' + ' '.join(cmds.table)

    elif cmds.table:
        raise Exception('ERROR: Keyspace must be specified with tables')

    playbook_args = {
        'nodes' : ' '.join(nodes),
        'snapshotter_command' : snapshotter_command,
        'save_schema_command' : save_schema_command,
        'path' : temp_path + '/' + title,
        'reload' : cmds.reload
    }

    # call playbook
    return_code = run_playbook('snapshot.yml', playbook_args)

    if return_code != 0:
        shutil.rmtree(temp_path + '/' + title)
        print('Error running ansible script')
    else:
        zip_dir(temp_path + '/' + title, save_path, title)

        if cmds.s3:
        
            file_size = os.path.getsize(save_path + '/' + title + '.zip')
            if confirm('Snapshot size is %i bytes. Upload? [y/n] ' % file_size):
                print('Uploading to s3 . . .')
                key = 'cassandra-snapshot-' + title
                upload = True
                if key in [obj.key for obj in s3.objects.all()]:
                    upload = confirm(('"%s" already exists in the S3 bucket. ' % key) +
                                     'Overwrite? [y/n]')
                if upload:
                    s3.upload_file(save_path + '/' + title + '.zip', key)
                    print('Uploaded with key "%s"' % key)
                else:
                    print('Skipping upload to s3 . . .')

        print('Process complete.')
        print('Output logs saved in %s' % (sys.path[0] + '/output_logs'))
        print('Snapshot zip saved in %s' % save_path)
Пример #26
0
def train():
    
    if FLAGS.load_model is not None:
        checkpoints_dir = "checkpoints/" + FLAGS.load_model.lstrip("checkpoints/")
    else:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        checkpoints_dir = "checkpoints/{}".format(current_time)
        try:
            os.makedirs(checkpoints_dir)
        except os.error:
            pass

    graph = tf.Graph()
    with graph.as_default():
        cycle_gan = CycleGAN(
            #X_train_file=FLAGS.X,
            #Y_train_file=FLAGS.Y,
            batch_size=FLAGS.batch_size,
            image_size=FLAGS.image_size,
            use_lsgan=FLAGS.use_lsgan,
            norm=FLAGS.norm,
            lambda1=FLAGS.lambda1,
            lambda2=FLAGS.lambda2,
            learning_rate=FLAGS.learning_rate,
            beta1=FLAGS.beta1,
            ngf=FLAGS.ngf
        )
        #G_loss, D_Y_loss, F_loss, D_X_loss, teacher_loss, student_loss, learning_loss, x_correct, y_correct, fake_y_correct, fake_y_pre
        G_loss, D_Y_loss, F_loss, D_X_loss, teacher_loss, student_loss, learning_loss, \
        x_correct, y_correct, fake_x_correct, softmax3, fake_x_pre, f_fakeX, fake_x, fake_y= cycle_gan.model()




        # G_loss, D_Y_loss, F_loss, D_X_loss, teacher_loss, student_loss, learning_loss,
        # x_correct, y_correct, fake_y_correct, softmax3, fake_y_pre, f_fakeX, fake_x, fake_y= cycle_gan.model()

        # softmax3,fake_y_pre,f_fakeX,fake_x,fake_y= cycle_gan.model()
        # optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss, teacher_loss, student_loss, learning_loss)
        # summary_op = tf.summary.merge_all()
        # train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        if FLAGS.load_model is not None:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            meta_graph_path = "checkpoints/20190611-1650/model.ckpt-94000.meta"
            print('meta_graph_path', meta_graph_path)
            restore = tf.train.import_meta_graph(meta_graph_path)
            restore.restore(sess, "checkpoints/20190611-1650/model.ckpt-94000")

            step = 0
            #meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
            #restore = tf.train.import_meta_graph(meta_graph_path)
            #restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
            #step = int(meta_graph_path.split("-")[2].split(".")[0])
        else:
            sess.run(tf.global_variables_initializer())
            step = 0

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            while not coord.should_stop():
                result_dir = './result'
                fake_dir = os.path.join(result_dir, 'fake_xy')
                roc_dir = os.path.join(result_dir, 'roc_curves')
                plot_dir = os.path.join(result_dir, 'tsne_pca')
                conv_dir = os.path.join(result_dir, 'convs')
                occ_dir = os.path.join(result_dir, 'occ_test')
                Xconv_dir =  os.path.join(result_dir, 'Xconv_dir')
                fakeXconv_dir = os.path.join(result_dir, 'fakeXconv_dir')
                Y_VGGconv_dir = os.path.join(result_dir, 'Y_VGGconv_dir')
                fakeY_VGGconv_dir = os.path.join(result_dir, 'fakeY_VGGconv_dir')

                rconv_dir = os.path.join(result_dir, 'resconvs')
                utils.prepare_dir(result_dir)
                utils.prepare_dir(occ_dir)
                utils.prepare_dir(fake_dir)
                utils.prepare_dir(roc_dir)
                utils.prepare_dir(plot_dir)
                utils.prepare_dir(conv_dir)
                utils.prepare_dir(rconv_dir)
                utils.prepare_dir(Xconv_dir)
                utils.prepare_dir(fakeXconv_dir)
                utils.prepare_dir(Y_VGGconv_dir)
                utils.prepare_dir(fakeY_VGGconv_dir)

                # x_image, x_label,oximage = get_batch_images(50,FLAGS.image_size,FLAGS.image_size,"./databet/X/0")
                # y_image, y_label,oyimage = get_batch_images(50,FLAGS.image_size,FLAGS.image_size,"./databet/Y/0")

                x_image, x_label, oximage = get_test_batch2("X", 1, FLAGS.image_size, FLAGS.image_size, "./dataset/")
                y_image, y_label, oyimage = get_test_batch2("Y", 1, FLAGS.image_size, FLAGS.image_size, "./dataset/")

                length3 = len(y_label)
                features_d = []
                fakeX_img =[]
                fakeY_img = []
                resconv_y_eval = []
                vggconv_x_eval = []
                X_Resconv_eval = []
                fakeX_Resconv_eval = []
                Y_VGGconv_eval = []
                fakeY_VGGconv_eval = []



                for i in range(length3):

                    ximgs = []
                    xlbs = []
                    yimgs = []
                    ylbs = []

                    ximgs.append(x_image[i])
                    xlbs.append(x_label[i])
                    yimgs.append(y_image[i])
                    ylbs.append(y_label[i])

                    # train

                    # SOFTMAX     roc
                    # fake_xpre   fake_xpre accuracy
                    # ffakeX[0]   tsne

                    (softmax, fake_xpre, ffakeX, fakeX, fakeY,
                     X_Resconv, fakeX_Resconv,
                     Y_VGGconv, fakeY_VGGconv) = sess.run(
                         [softmax3, fake_x_pre, f_fakeX, fake_x, fake_y,
                          tf.get_collection('X_Resconv'),
                          tf.get_collection('fakeX_Resconv'),
                          tf.get_collection('Y_VGGconv'),
                          tf.get_collection('fakeY_VGGconv')],
                         feed_dict={cycle_gan.x: ximgs, cycle_gan.y: yimgs,
                                    cycle_gan.x_label: xlbs, cycle_gan.y_label: ylbs})
                    Uy = softmax[0]
                    fake_x_img = (np.array(fakeX[0]) + 1.0) * 127.5
                    fake_x_img = cv2.cvtColor(fake_x_img, cv2.COLOR_RGB2BGR)
                    fake_y_img = (np.array(fakeY[0]) + 1.0) * 127.5
                    fake_y_img = cv2.cvtColor(fake_y_img, cv2.COLOR_RGB2BGR)
                    fakeX_img.append(fake_x_img)
                    fakeY_img.append(fake_y_img)

                    X_Resconv_eval.append(X_Resconv)
                    fakeX_Resconv_eval.append(fakeX_Resconv)
                    Y_VGGconv_eval.append(Y_VGGconv)
                    fakeY_VGGconv_eval.append(fakeY_VGGconv)

                    features_d.append(ffakeX[0])

                # # T SNE    -----  PCA
                # print('features_d',len(features_d))
                # tsne = TSNE(n_components=2, learning_rate=4).fit_transform(features_d)
                # pca = PCA().fit_transform(features_d)
                # plt.figure(figsize=(12, 6))
                # plt.subplot(121)
                # plt.scatter(tsne[:, 0], tsne[:, 1], c=y_label)
                # plt.subplot(122)
                # plt.scatter(pca[:, 0], pca[:, 1], c=y_label)
                # plt.colorbar()  # the colorbar shows which color corresponds to which class
                #
                # plt.savefig(os.path.join(plot_dir, 'plot.pdf'))



                # Cross Domain Image Generation#

                for i in range(length3):
                    file_nameOX = os.path.join(fake_dir, str(i) + '_oriX.png')
                    cv2.imwrite(file_nameOX, oximage[i])
                    file_name_fakeX = os.path.join(fake_dir, str(i) + '_fakeX.png')
                    cv2.imwrite(file_name_fakeX, fakeX_img[i])
                    file_nameOY = os.path.join(fake_dir, str(i) + '_oriY.png')
                    cv2.imwrite(file_nameOY, oyimage[i])
                    file_name_fakeY = os.path.join(fake_dir, str(i) + '_fakeY.png')
                    cv2.imwrite(file_name_fakeY, fakeY_img[i])



                # Feature map visualization: randomly pick images and generate fake_X
                width = height = 256
                vggconv = vggconv_x_eval
                for step in range(length3):
                    id_x_dir = os.path.join(Xconv_dir, str(step))
                    print('id_x_dir', id_x_dir)
                    for i, c in enumerate(X_Resconv_eval[step]):
                        plot_conv_output(c, i, id_x_dir)
                        print('Res%d' % i)
                    cv2.imwrite(os.path.join(id_x_dir, 'X.png'), oximage[step])

                for step in range(length3):
                    id_fakex_dir = os.path.join(fakeXconv_dir, str(step))
                    print('id_fakex_dir', id_fakex_dir)
                    for i, c in enumerate(fakeX_Resconv_eval[step]):
                        plot_conv_output(c, i, id_fakex_dir)
                        print('fakeRes%d' %i)
                    cv2.imwrite(os.path.join(id_fakex_dir, 'fakeX.png'), fakeX_img[step])




                for step in range(length3):
                    id_y_dir = os.path.join(Y_VGGconv_dir, str(step))
                    print('id_y_dir', id_y_dir)
                    # vgg = []
                    # vgg.append(vggconv_x_eval[step])
                    # print('shape vgg', np.shape(vggconv_x_eval[step][0]))
                    for i, c in enumerate(Y_VGGconv_eval[step]):
                        plot_conv_output(c, i, id_y_dir)
                        print('VGG%d' % i)
                    cv2.imwrite(os.path.join(id_y_dir, 'y.png'), oyimage[step])


                for step in range(length3):
                    id_fakey_dir = os.path.join(fakeY_VGGconv_dir, str(step))
                    print('id_fakey_dir', id_fakey_dir)
                    # vgg = []
                    # vgg.append(vggconv_x_eval[step])
                    # print('shape vgg', np.shape(vggconv_x_eval[step][0]))
                    for i, c in enumerate(fakeY_VGGconv_eval[step]):
                        plot_conv_output(c, i, id_fakey_dir)
                        print('fakeVGG%d' % i)
                    cv2.imwrite(os.path.join(id_fakey_dir, 'fakeY.png'), fakeY_img[step])

                # Occlusion sensitivity test: slide a white patch over the image
                # and record how much the softmax confidence drops at each position.
                image = y_image[1]
                width = height = 256
                occluded_size = 16
                data = np.empty((width * height + 1, width, height, 3), dtype="float32")
                print('data  ---')
                data[0, :, :, :] = image
                cnt = 1
                for i in range(height):
                    for j in range(width):
                        i_min = int(i - occluded_size / 2)
                        i_max = int(i + occluded_size / 2)
                        j_min = int(j - occluded_size / 2)
                        j_max = int(j + occluded_size / 2)
                        if i_min < 0:
                            i_min = 0
                        if i_max > height:
                            i_max = height
                        if j_min < 0:
                            j_min = 0
                        if j_max > width:
                            j_max = width
                        data[cnt, :, :, :] = image
                        data[cnt, i_min:i_max, j_min:j_max, :] = 255
                        # print(data[i].shape)
                        cnt += 1
                #
                # [idx_u]=np.where(np.max(Uy[id_y]))


                # [idx_u]=np.where(np.max(Uy))

                u_ys = np.empty([data.shape[0]], dtype='float64')
                occ_map = np.empty((width, height), dtype='float64')

                print('occ_map.shape', occ_map.shape)
                cnt = 0
                feature_y_eval = sess.run(
                        softmax3,
                        feed_dict={cycle_gan.y: [data[0]]})#

                # print('softmax3',feature_y_eval.eval())

                
                u_y0 = feature_y_eval[0]
                [idx_u] = np.where(np.max(u_y0))
                idx_u = idx_u[0]
                print('feature_y_eval', feature_y_eval)
                print('u_y0', u_y0)
                print('len u_y0', len(u_y0))
                # baseline: the highest softmax score of the unoccluded image
                u_y0 = float(np.max(u_y0))
                # print('max', u_y0[idx_u])
                print('max', u_y0)
                # print('u_y01', u_y0[idx_u])

                for i in range(width):
                    for j in range(height):
                        feature_y_eval = sess.run(
                            softmax3,
                            feed_dict={cycle_gan.y: [data[cnt + 1]]})
                        u_y = feature_y_eval[0]
                        print('u_y', u_y)
                        # highest softmax score for the occluded image
                        u_y1 = float(np.max(u_y))

                        # confidence drop caused by occluding this position
                        occ_value = u_y0 - u_y1
                        occ_map[i, j] = occ_value
                        print(str(cnt) + ':' + str(occ_value))
                        cnt += 1

                occ_map_path = os.path.join(occ_dir, 'occlusion_map_{}.txt'.format('1'))
                np.savetxt(occ_map_path, occ_map, fmt='%0.8f')
                cv2.imwrite(os.path.join(occ_dir, '{}.png'.format('1')), oyimage[1])
                draw_heatmap(occ_map_path=occ_map_path,
                             ori_img=oyimage[1],
                             save_dir=os.path.join(occ_dir, 'heatmap_{}.png'.format('1')))

                exit()

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            coord.request_stop(e)
        finally:

            # When done, ask the threads to stop.
            coord.request_stop()
            coord.join(threads)
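The occlusion test hands the saved occlusion map to a draw_heatmap() helper that is not part of the snippet. A minimal matplotlib-based sketch with the same keyword arguments (an assumption about the real helper, which may differ) could be:

import cv2
import numpy as np
import matplotlib.pyplot as plt

def draw_heatmap(occ_map_path, ori_img, save_dir):
    # hypothetical helper: overlay the occlusion-sensitivity map on the
    # original (BGR) image and save the figure to save_dir
    occ_map = np.loadtxt(occ_map_path)
    plt.figure(figsize=(6, 6))
    plt.imshow(cv2.cvtColor(ori_img.astype(np.uint8), cv2.COLOR_BGR2RGB))
    plt.imshow(occ_map, cmap='jet', alpha=0.5)
    plt.colorbar()
    plt.axis('off')
    plt.savefig(save_dir, bbox_inches='tight')
    plt.close()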
Пример #27
0
def train(cli_params):
    cli_params['save_dir'] = prepare_dir(cli_params['save_to'])
    logfile = os.path.join(cli_params['save_dir'], 'log.txt')

    # Log also DEBUG to a file
    fh = logging.FileHandler(filename=logfile)
    fh.setLevel(logging.DEBUG)
    logger.addHandler(fh)

    logger.info('Logging into %s' % logfile)

    p, loaded = load_and_log_params(cli_params)
    in_dim, data, whiten, cnorm = setup_data(p, test_set=False)
    if not loaded:
        # Set the zero layer to match input dimensions
        p.encoder_layers = (in_dim, ) + p.encoder_layers

    ladder = setup_model(p)

    # Training
    all_params = ComputationGraph([ladder.costs.total]).parameters
    logger.info('Found the following parameters: %s' % str(all_params))

    # Fetch all batch normalization updates. They are in the clean path.
    bn_updates = ComputationGraph([ladder.costs.class_clean]).updates
    assert 'counter' in [u.name for u in bn_updates.keys()], \
        'No batch norm params in graph - the graph has been cut?'

    training_algorithm = GradientDescent(
        cost=ladder.costs.total,
        parameters=all_params,
        step_rule=Adam(learning_rate=ladder.lr))
    # In addition to actual training, also do BN variable approximations
    training_algorithm.add_updates(bn_updates)

    short_prints = {
        "train": {
            'T_C_class': ladder.costs.class_corr,
            'T_C_de': ladder.costs.denois.values(),
        },
        "valid_approx":
        OrderedDict([
            ('V_C_class', ladder.costs.class_clean),
            ('V_E', ladder.error.clean),
            ('V_C_de', ladder.costs.denois.values()),
        ]),
        "valid_error":
        OrderedDict([
            ('VE_C_class', ladder.costs.class_clean),
            ('VE_E', ladder.error.clean),
            ('VE_C_de', ladder.costs.denois.values()),
        ]),
    }

    main_loop = MainLoop(
        training_algorithm,
        # Datastream used for training
        make_datastream(data.train,
                        data.train_ind,
                        p.batch_size,
                        n_labeled=p.labeled_samples,
                        n_unlabeled=p.unlabeled_samples,
                        whiten=whiten,
                        cnorm=cnorm),
        model=Model(ladder.costs.total),
        extensions=[
            #FinishAfter(after_n_epochs=p.num_epochs),

            # This will estimate the validation error using
            # running average estimates of the batch normalization
            # parameters, mean and variance
            #ApproxTestMonitoring(
            #    [ladder.costs.class_clean, ladder.error.clean]
            #    + ladder.costs.denois.values(),
            #    make_datastream(data.valid, data.valid_ind,
            #                    p.valid_batch_size, whiten=whiten, cnorm=cnorm,
            #                    scheme=ShuffledScheme),
            #    prefix="valid_approx"),

            # This Monitor is slower, but more accurate since it will first
            # estimate batch normalization parameters from training data and
            # then do another pass to calculate the validation error.
            FinalTestMonitoring(
                [ladder.costs.class_clean, ladder.error.clean] +
                ladder.costs.denois.values(),
                make_datastream(data.train,
                                data.train_ind,
                                p.batch_size,
                                n_labeled=p.labeled_samples,
                                whiten=whiten,
                                cnorm=cnorm,
                                scheme=ShuffledScheme),
                make_datastream(data.valid,
                                data.valid_ind,
                                p.valid_batch_size,
                                n_labeled=len(data.valid_ind),
                                whiten=whiten,
                                cnorm=cnorm,
                                scheme=ShuffledScheme),
                prefix="valid_error",
                every_n_epochs=1),
            TrainingDataMonitoring([
                ladder.costs.total, ladder.costs.class_corr,
                training_algorithm.total_gradient_norm
            ] + ladder.costs.denois.values(),
                                   prefix="train",
                                   after_epoch=True),

            #SaveParams(None, all_params, p.save_dir, after_epoch=True),
            #SaveExpParams(p, p.save_dir, before_training=True),
            #SaveLog(p.save_dir, after_training=True),
            #ShortPrinting(short_prints),
            #ProgressBar(),
            StopAfterNoImprovementValidation('valid_error_error_rate_clean',
                                             10, all_params, p, p.save_dir),
            #LRDecay(ladder.lr, p.num_epochs * p.lrate_decay, p.num_epochs,
            #        after_epoch=True),
        ])
    main_loop.run()

    # Get results
    df = DataFrame.from_dict(main_loop.log, orient='index')
    col = 'valid_error_error_rate_clean'
    logger.info('%s %g' % (col, df[col].iloc[-1]))

    if main_loop.log.status['epoch_interrupt_received']:
        return None
    return df
Пример #28
0
def main():
    parser = argparse.ArgumentParser()

    # Required Parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--model_type",
                        default=None,
                        type=str,
                        help="bert OR lstm",
                        required=True)
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output_result directory where the model predictions will be written."
    )
    parser.add_argument("--output_mode",
                        default="regression",
                        type=str,
                        help="classification or regression",
                        required=True)
    parser.add_argument("--domain",
                        default="celtrion",
                        type=str,
                        help="celtrion",
                        required=True)
    parser.add_argument("--target",
                        default="close",
                        type=str,
                        help="close, open, volume",
                        required=True)

    # Other Parameters
    parser.add_argument("--use_gpu",
                        help="use gpu=True or False",
                        default=True)
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--per_gpu_train_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for classifier.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for classifier.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval.")
    parser.add_argument("--window_size",
                        default=50,
                        type=int,
                        help="window size for lstm")

    args = parser.parse_args()

    # use GPU ?
    if not is_gpu_available():
        args.use_gpu = False

    # Model
    model_type = args.model_type
    output_mode = args.output_mode

    # data
    data_root = args.data_dir

    # output
    output_root = args.output_dir
    prepare_dir(output_root)

    fns = {
        'input': {
            'train': os.path.join(data_root, 'train.csv'),
            'test': os.path.join(data_root, 'test.csv')
        },
        'output': {
            # 'csv': os.path.join(...)  # add here if needed
        },
        'model': os.path.join(output_root, 'model.out')
    }

    # Train
    if args.do_train:
        hps = HParams(
            # domain -------------------------------------------
            domain=args.domain,
            target=args.target,

            # gpu setting ----------------------------------------
            use_gpu=args.use_gpu,

            # train settings ----------------------------------------
            learning_rate=args.learning_rate,
            num_train_epochs=args.num_train_epochs,
            per_gpu_train_batch_size=args.per_gpu_train_batch_size,
            window_size=args.window_size,

            # model settings ----------------------------------------
            model_type=model_type,
            output_mode=output_mode)

        hps.show()

        print("*********** Start Training ***********")
        run_train(fns['input']['train'], fns['model'], hps)

    if args.do_eval:
        print("*********** Start Evaluating ***********")

        batch_size = args.per_gpu_eval_batch_size

        run_eval(fns['input']['test'], fns['model'], batch_size)
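run_train() receives its configuration through an HParams object built from keyword arguments and printed with show(). The class itself is not shown; a small container compatible with the usage above (an assumption, not the project's actual class) could look like:

class HParams:
    # hypothetical hyper-parameter container: store keyword arguments as
    # attributes and print them with show()
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    def show(self):
        for key, value in sorted(vars(self).items()):
            print('{:30s}: {}'.format(key, value))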
Пример #29
0
def ansible_restore(cmds):

    if not (bool(cmds.path) ^ bool(cmds.s3)):
        raise Exception('Exactly one of --path or --s3 must be specified')

    if not cmds.nodes:
        config = ConfigParser()
        if len(config.read('config.ini')) == 0:
            raise Exception('ERROR: Cannot find config.ini in script directory')
        nodes = re.findall(r'[^,\s\[\]]+', config.get('cassandra-info', 'hosts'))
        if not nodes:
            raise Exception('Hosts argument in config.ini not specified')
    else:
        nodes = cmds.nodes

    # prepare working directories
    temp_path = sys.path[0] + '/.temp'
    prepare_dir(sys.path[0] + '/output_logs', output=True)
    prepare_dir(temp_path, output=True)
    
    if cmds.path:
        zip_path = cmds.path
    elif cmds.s3:
        s3 = s3_bucket()
        s3_snapshots = s3_list_snapshots(s3)

        if cmds.s3 == True: # not a string parameter
            if len(s3_snapshots) == 0:
                print('No snapshots found in s3')
                exit(0)

            # search 
            print('\nSnapshots found:')
            template = '{0:5} | {1:67}'
            print(template.format('Index', 'Snapshot'))
            for idx, snap in enumerate(s3_snapshots):
                # every snapshot starts with cassandra-snapshot- (19 chars)
                stripped = snap[19:] 
                print(template.format(idx + 1, stripped))
            
            index = 0
            while index not in range(1, len(s3_snapshots) + 1):
                try:
                    index = int(raw_input('Enter snapshot index: '))
                except ValueError:
                    continue
            s3_key = s3_snapshots[index - 1]

        else:
            s3_key = cmds.s3
            if not s3_key.startswith('cassandra-snapshot-'):
                s3_key = 'cassandra-snapshot-' + s3_key

            if s3_key not in s3_snapshots:
                raise Exception('S3 Snapshot not found')

        print('Retrieving snapshot from S3: %s' % s3_key)
        s3.download_file(s3_key, temp_path + '/temp.zip') 
        zip_path = temp_path + '/temp.zip'
    else:
        raise Exception('No file specified.')

    # unzip 
    print('Unzipping snapshot file')
    z = zipfile.ZipFile(zip_path, 'r')
    z.extractall(temp_path)

    # check schema specification args
    print('Checking arguments . . .')
    restore_command = 'restore.py '
    load_schema_command = 'load_schema.py '
    if cmds.keyspace:

        schema = get_zipped_schema(temp_path + '/schemas.zip')
        for keyspace in cmds.keyspace:
            if keyspace not in schema.keys():
                raise Exception('ERROR: Keyspace "%s" not in snapshot schema' % keyspace)

        keyspace_arg = '-ks ' + ' '.join(cmds.keyspace)
        restore_command += keyspace_arg
        load_schema_command += keyspace_arg
                
        if cmds.table:

            if len(cmds.keyspace) != 1:
                raise Exception('ERROR: One keyspace must be specified with table argument')

            ks = cmds.keyspace[0]
            for tb in cmds.table:
                if tb not in schema[ks]:
                    raise Exception('ERROR: Table "%s" not found in keyspace "%s"' % (tb, ks))

            restore_command += ' -tb ' + ' '.join(cmds.table)

    elif cmds.table:
        raise Exception('ERROR: Keyspace must be specified with tables')

    playbook_args = {
        'nodes': ' '.join(nodes),
        'restore_command' : restore_command,
        'load_schema_command' : load_schema_command,
        'reload' : cmds.reload,
        'hard_reset' : cmds.hard_reset
    }
    return_code = run_playbook('restore.yml', playbook_args)
    
    if return_code != 0:
        print('ERROR: Ansible script failed to run properly. ' +
              'If this persists, try --hard-reset. (TODO)') # TODO
    else:
        print('Process complete.')
        print('Output logs saved in %s' % (sys.path[0] + '/output_logs'))