Example #1
def get_samples(target, nb_class=10, sample_index=0):
    '''
    params:
        target : 'mnist' or 'cifar10'
        nb_class : number of classes
        sample_index : per-class index of the image to select

    returns:
        original_images (numpy array) : original images, shape = (nb_class, W, H, C)
        original_targets (torch tensor) : label of each selected image, shape = (nb_class,)
        pre_images (torch tensor) : preprocessed images, shape = (nb_class, C, W, H)
        target_classes (dictionary) : keys = class index, values = class name
        model (pytorch model) : pretrained model
    '''

    if target == 'mnist':
        image_size = (28, 28, 1)

        _, _, testloader = mnist_load()
        testset = testloader.dataset

    elif target == 'cifar10':
        image_size = (32, 32, 3)

        _, _, testloader = cifar10_load()
        testset = testloader.dataset

    else:
        raise ValueError("target must be 'mnist' or 'cifar10', got {!r}".format(target))

    # invert class_to_idx into an idx -> class-name mapping
    target_class2idx = testset.class_to_idx
    target_classes = {idx: name for name, idx in target_class2idx.items()}

    # select images
    idx_by_class = [
        np.where(np.array(testset.targets) == i)[0][sample_index]
        for i in range(nb_class)
    ]
    original_images = testset.data[idx_by_class]
    if not isinstance(original_images, np.ndarray):
        original_images = original_images.numpy()
    original_images = original_images.reshape((nb_class, ) + image_size)
    # select targets
    if isinstance(testset.targets, list):
        original_targets = torch.LongTensor(testset.targets)[idx_by_class]
    else:
        original_targets = testset.targets[idx_by_class]

    # model load
    weights = torch.load('../checkpoint/simple_cnn_{}.pth'.format(target))
    model = SimpleCNN(target)
    model.load_state_dict(weights['model'])

    # image preprocessing: apply the dataset transform to each image -> (nb_class, C, W, H)
    pre_images = torch.zeros((nb_class, image_size[2], image_size[0], image_size[1]))
    for i in range(len(original_images)):
        pre_images[i] = testset.transform(original_images[i])

    return original_images, original_targets, pre_images, target_classes, model
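A minimal usage sketch for get_samples (assuming a trained checkpoint exists at ../checkpoint/simple_cnn_mnist.pth, as produced by the training script in Example #15):

images, targets, pre_images, classes, model = get_samples('mnist', nb_class=10, sample_index=0)
model.eval()
with torch.no_grad():
    logits = model(pre_images)             # forward pass on the preprocessed batch
preds = logits.argmax(dim=1)               # predicted class index per image
print([classes[p.item()] for p in preds])  # map indices back to class names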
Example #2
class ParameterServer(object):
    def __init__(self, learning_rate):
        self.net = SimpleCNN(learning_rate=learning_rate)

    def apply_gradients(self, *gradients):
        # Average the workers' gradients element-wise, then apply them.
        self.net.apply_gradients(np.mean(gradients, axis=0))
        return self.net.variables.get_flat()

    def get_weights(self):
        return self.net.variables.get_flat()
Example #3
class Worker(object):
    def __init__(self, worker_index, batch_size=50):
        self.worker_index = worker_index
        self.batch_size = batch_size
        self.mnist = download_mnist_retry(seed=worker_index)
        self.net = SimpleCNN()

    def compute_gradients(self, weights):
        self.net.variables.set_flat(weights)
        xs, ys = self.mnist.train.next_batch(self.batch_size)
        return self.net.compute_gradients(xs, ys)
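For these two classes to be used as Ray actors (as Example #19 does), they would be declared with @ray.remote; a minimal driver-loop sketch under that assumption:

import ray

ray.init()
ps = ParameterServer.remote(learning_rate=1e-4)
workers = [Worker.remote(worker_index) for worker_index in range(4)]

current_weights = ps.get_weights.remote()
for _ in range(200):
    # Each worker computes gradients against the latest weights;
    # the parameter server averages them and returns the new weights.
    gradients = [w.compute_gradients.remote(current_weights) for w in workers]
    current_weights = ps.apply_gradients.remote(*gradients)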
Example #4
def __init__(self, dataset='mnist', model='simplecnn', **kwargs):
    if dataset == 'mnist':
        self.dataset = Mnist()
    elif dataset == 'cifar10':
        self.dataset = Cifar10()
    else:
        raise NotImplementedError
    if model == 'simplecnn':
        self.model = SimpleCNN(hyper_mode=True)
    elif model == 'resnet50':
        self.model = Resnet50(hyper_mode=True, num_classes=self.dataset.num_classes)
    else:
        raise NotImplementedError
    self.x_dim = kwargs.pop('x_dim', 28)
    self.c_dim = kwargs.pop('c_dim', 1)
    self.num_classes = kwargs.pop('num_classes', 10)
    self.batch_size = kwargs.pop('batch_size', 1024)
    self.max_epoch = kwargs.pop('max_epoch', 50)
    self.learning_rate = kwargs.pop('learning_rate', 0.0005)
    self.lr_decay = kwargs.pop('lr_decay', 0.99)
    self.grad_clip = kwargs.pop('grad_clip', 100.0)
    self.optimize_method = kwargs.pop('optimizer', 'adam')
    self.logpath = kwargs.pop('logpath', 'log')
Example #5
def main():
    parser = argparse.ArgumentParser()
    # Make this case-insensitive?
    parser.add_argument('--model',
                        type=str,
                        required=True,
                        choices=model_list,
                        help='Model to use')
    parser.add_argument('--config',
                        type=str,
                        default='config/default.yaml',
                        help='Config file to use')
    parser.add_argument('--checkpoint',
                        '-chkp',
                        type=str,
                        help='Checkpoint of the model to use')
    parser.add_argument('--cpu',
                        action='store_true',
                        help='Use cpu only (Overrides some hyperparameters)')
    args = parser.parse_args()

    hp = OmegaConf.load(args.config)
    if args.cpu:
        hp.train.batch_size = 16

    if args.model == 'SimpleCNN':
        model = SimpleCNN(hp=hp)
    elif args.model == 'SATNet':
        model = SATNet(hp=hp)
    else:
        raise RuntimeError("Wrong model name %s" % args.model)

    train_name = '%s_%s' % (args.model, hp.data.name)
    logger = loggers.TensorBoardLogger('logs/', name=train_name)
    logger.log_hyperparams(OmegaConf.to_container(hp))

    trainer = Trainer(
        gpus=None if args.cpu else -1,
        logger=logger,
        resume_from_checkpoint=args.checkpoint,
        max_epochs=100000,
    )
    trainer.fit(model)
    trainer.test(model)
Example #6
def train_model(train_glob: str, checkpoint_root: str, tensorboard_root: str):

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    fs = gcsfs.GCSFileSystem(token='cloud')

    logger.log(logging.INFO, 'Creating model')
    model = SimpleCNN()

    logger.log(logging.INFO, 'Opening files from: {}'.format(train_glob))
    dataset = pq.ParquetDataset(train_glob, filesystem=fs)

    logger.log(logging.INFO, 'Creating dataset')
    train_dataset = IterableParquetDataset(
        dataset,
        32,
        process_func=process_image,
        columns=[
            'image/class/label',  # TODO: should these be hard-coded...?
            'image/encoded'
        ])

    logger.log(logging.INFO, 'Creating data loader')
    dataloader = DataLoader(train_dataset)

    #tboard = TensorBoardLogger(tensorboard_root)
    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(checkpoint_root, 'weights.ckpt'))

    logger.log(logging.INFO, 'Starting training')
    trainer = pl.Trainer(
        #logger=tboard,
        checkpoint_callback=checkpoint_callback,
        max_epochs=1)

    trainer.fit(model, dataloader)

    return checkpoint_callback.best_model_path
Example #7
def worker_task(ps, worker_index, batch_size=50):
    # Download MNIST.
    print("Worker " + str(worker_index))
    mnist = download_mnist_retry(seed=worker_index)

    # Initialize the model.
    net = SimpleCNN()
    keys = net.get_weights()[0]

    while True:
        # Get the current weights from the parameter server.
        weights = ray.get(ps.pull.remote(keys))
        net.set_weights(keys, weights)
        # Compute an update and push it to the parameter server.
        xs, ys = mnist.train.next_batch(batch_size)
        gradients = net.compute_update(xs, ys)
        ps.push.remote(keys, gradients)
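worker_task talks to the parameter server only through pull and push; a sketch of an actor consistent with those calls and with the ParameterServer.remote(all_keys, all_values) construction in Example #16 (the dict-of-arrays storage is an assumption):

import ray

@ray.remote
class ParameterServer(object):
    def __init__(self, keys, values):
        # Keep one array per weight name; copy so callers cannot mutate them.
        self.weights = dict(zip(keys, [value.copy() for value in values]))

    def push(self, keys, values):
        # Apply a gradient update to each named weight.
        for key, value in zip(keys, values):
            self.weights[key] += value

    def pull(self, keys):
        return [self.weights[key] for key in keys]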
Example #8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('datapath', type=str, help='Path to files to infer')
    parser.add_argument('--checkpoint',
                        '-chkp',
                        type=str,
                        required=True,
                        help='Checkpoint of the model to use')
    parser.add_argument('--model',
                        type=str,
                        default='SATNet',
                        help='Model to use')
    parser.add_argument('--cpu',
                        action='store_true',
                        help='Use cpu for inference')
    args = parser.parse_args()

    if args.model == 'SimpleCNN':
        model = SimpleCNN.load_from_checkpoint(args.checkpoint)
    elif args.model == 'SATNet':
        model = SATNet.load_from_checkpoint(args.checkpoint)
    else:
        raise RuntimeError("Wrong model name %s" % args.model)
    model.freeze()

    train_name = 'infer_%s' % args.model
    logger = loggers.TensorBoardLogger('logs/', name=train_name)

    dataset = DemoDataset(args.datapath)
    loader = DataLoader(dataset, num_workers=64, shuffle=False)

    trainer = Trainer(gpus=None if args.cpu else -1,
                      logger=logger,
                      callbacks=[TestCallback()])

    trainer.test(model, test_dataloaders=loader)
    print("Done!")
Example #9
def main():
    args = parse_args()

    with open(args.config_file, 'r') as f:
        cfg = yaml.safe_load(f)

    # load dataset
    (train_imgs, train_labels), (test_imgs, test_labels) = load_dataset()
    print("Train image shape: ", train_imgs.shape)
    print("Test image shape: ", test_imgs.shape)

    # train
    cnn_model = SimpleCNN(cfg['global']['tmp_path'])
    cnn_model.train(cfg['global']['epoch'], train_imgs, train_labels,
                    test_imgs, test_labels)

    # export as a SavedModel
    cnn_model.export_SavedModel(cfg['tf-serving']['version_number'],
                                cfg['tf-serving']['export_dir'],
                                weight_path=os.path.join(
                                    cfg['global']['tmp_path'], "model.h5"))
Example #10
        # model_path = os.path.join(args.train_dir, 'checkpoint_%d.pth.tar' % args.inference_version)
        # if os.path.exists(model_path):
        # 	cnn_model = torch.load(model_path)

        avg_val_acc = 0.0
        avg_val_roc = 0.0

        best_val_acc_global = 0.0
        best_val_roc_global = 0.0

        for i in range(5):
            if args.model == "parallel":
                model = ParallelCNN(max_len, drop_rate=args.drop_rate)
            elif args.model == 'cnn':
                model = SimpleCNN(max_len, drop_rate=args.drop_rate)
            elif args.model == 'siamese':
                model = SiameseCNN(max_len, drop_rate=args.drop_rate)
            else:
                raise ValueError('unknown model: %s' % args.model)
            model.to(device)
            optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

            pre_losses = [1e18] * 3
            best_val_acc = 0.0
            best_val_roc = 0.0

            X_train = []
            y_train = []
            for j in range(5):
                if j != i:  # train on the four folds other than fold i
                    X_train.extend(fold_list[j][0])
                    y_train.extend(fold_list[j][1])
Example #11
def __init__(self, learning_rate):
    self.net = SimpleCNN(learning_rate=learning_rate)
Example #12
def __init__(self, worker_index, batch_size=50):
    self.worker_index = worker_index
    self.batch_size = batch_size
    self.mnist = download_mnist_retry(seed=worker_index)
    self.net = SimpleCNN()
Example #13
class Solver(object):
    def __init__(self, dataset='mnist', model='simplecnn', **kwargs):
        if dataset == 'mnist':
            self.dataset = Mnist()
        elif dataset == 'cifar10':
            self.dataset = Cifar10()
        else:
            raise NotImplementedError
        if model == 'simplecnn':
            self.model = SimpleCNN(hyper_mode=True)
        elif model == 'resnet50':
            self.model = Resnet50(hyper_mode=True, num_classes=self.dataset.num_classes)
        else:
            raise NotImplementedError
        self.x_dim = kwargs.pop('x_dim', 28)
        self.c_dim = kwargs.pop('c_dim', 1)
        self.num_classes = kwargs.pop('num_classes', 10)
        self.batch_size = kwargs.pop('batch_size', 1024)
        self.max_epoch = kwargs.pop('max_epoch', 50)
        self.learning_rate = kwargs.pop('learning_rate', 0.0005)
        self.lr_decay = kwargs.pop('lr_decay', 0.99)
        self.grad_clip = kwargs.pop('grad_clip', 100.0)
        self.optimize_method = kwargs.pop('optimizer', 'adam')
        self.logpath = kwargs.pop('logpath', 'log')

    def train(self):
        x_train, y_train = self.dataset.x_train, self.dataset.y_train
        x_test, y_test = self.dataset.x_test, self.dataset.y_test

        batch_images = tf.placeholder(dtype=tf.float32, shape=[None, self.x_dim, self.x_dim, self.c_dim])
        batch_labels = tf.placeholder(dtype=tf.float32, shape=[None, self.num_classes])
        loss_op, prob_op, predict_op = self.model.build_model(batch_images, batch_labels)

        n_samples = len(x_train)
        n_iterations = int(ceil(n_samples / float(self.batch_size)))
        # learning rate
        global_step = tf.Variable(0, trainable=False)
        lr = tf.train.exponential_decay(self.learning_rate, global_step=global_step, decay_steps=n_iterations,
                                        decay_rate=self.lr_decay)
        if self.optimize_method == 'adam':
            optimizer = tf.train.AdamOptimizer(lr)
        else:
            raise NotImplementedError
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(loss_op, tvars), self.grad_clip)
        train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=global_step)

        for var in tvars:
            tf.summary.histogram(var.op.name, var)
        for grad, var in zip(grads, tvars):
            if grad is not None:
                tf.summary.histogram(var.op.name + '/gradient', grad)
        # Create a summary to monitor cost tensor
        tf.summary.scalar("loss", loss_op)
        # Merge all summaries into a single op
        merged_summary_op = tf.summary.merge_all()

        sess = tf.Session()
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        summary_writer = tf.summary.FileWriter(self.logpath, graph=tf.get_default_graph())
        for epoch in range(self.max_epoch):
            # shuffle data
            permutation = np.random.permutation(n_samples)
            x_train = x_train[permutation]
            y_train = y_train[permutation]
            for i in range(n_iterations):
                batch_x = x_train[i * self.batch_size:(i + 1) * self.batch_size]
                batch_y = y_train[i * self.batch_size:(i + 1) * self.batch_size]
                _, loss, summary_str = sess.run([train_op, loss_op, merged_summary_op],
                                                feed_dict={batch_images: batch_x, batch_labels: batch_y})
                summary_writer.add_summary(summary_str, epoch * n_iterations + i)
            train_loss, train_acc = self.evaluate_in_batch(x_train, y_train, sess, loss_op, predict_op, batch_images,
                                                           batch_labels)
            test_loss, test_acc = self.evaluate_in_batch(x_test, y_test, sess, loss_op, predict_op, batch_images,
                                                         batch_labels)
            learn_rate = sess.run(optimizer._lr)
            print('Epoch %3d: train loss %.6f, train acc %.6f, test loss %.6f, test acc %.6f, lr %.6f'
                  % (epoch, train_loss, train_acc, test_loss, test_acc, learn_rate))

    def evaluate_in_batch(self, x, y, sess, loss_op, predict_op, x_placeholder, y_placeholder):
        n_samples = len(x)
        n_iterations = int(ceil(n_samples / float(self.batch_size)))
        y_pred = np.zeros([y.shape[0]])
        losses = []
        for i in range(n_iterations):
            batch_x = x[i * self.batch_size:(i + 1) * self.batch_size]
            batch_y = y[i * self.batch_size:(i + 1) * self.batch_size]
            loss, ans = sess.run([loss_op, predict_op], feed_dict={x_placeholder: batch_x, y_placeholder: batch_y})
            losses.append(loss)
            y_pred[i * self.batch_size:(i + 1) * self.batch_size] = ans
        accuracy = np.mean(np.equal(y_pred, np.argmax(y, axis=1)).astype(np.float32))
        return np.sum(losses) / n_samples, accuracy
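A short usage sketch for Solver (the hyperparameter values are illustrative, not the project's defaults):

solver = Solver(dataset='mnist', model='simplecnn',
                batch_size=256, max_epoch=10, logpath='log/mnist')
solver.train()  # Adam + exponential LR decay, summaries written to logpath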
Example #14
        print("fit")
        self.model.fit(
            x=self.train_dataset,
            steps_per_epoch=self.train_size//batch_size,
            validation_data=self.validation_dataset,
            validation_steps=self.validation_size//batch_size,
            epochs=epochs
        )
#            callbacks=callback
#        )

if __name__ == "__main__":
    from dataset import MnistDataset
    from dataset import DatasetUtil
    from model import SimpleCNN
    from model import SimpleSoftmaxClassificationModel

    train = MnistDataset.get_train_set().map(DatasetUtil.image_classification_util(10))
    validation = MnistDataset.get_validation_set().map(DatasetUtil.image_classification_util(10))
    train_size, validation_size, _ = MnistDataset.get_data_size()
    base = SimpleCNN.get_base_model(28, 28, 1)
    softmax_model = SimpleSoftmaxClassificationModel.get_classification_model(base, 10)
    trainer: Trainer = Trainer(
        softmax_model,
        train=train,
        train_size=train_size,
        validation=validation,
        validation_size=validation_size
    )
    trainer.train()
Example #15
def main(args, **kwargs):
    # Config
    epochs = args.epochs
    batch_size = args.batch_size
    valid_rate = args.valid_rate
    lr = args.lr
    verbose = args.verbose

    # checkpoint
    target = args.target
    monitor = args.monitor
    mode = args.mode

    # save name
    model_name = 'simple_cnn_{}'.format(target)

    # save directory
    savedir = '../checkpoint'
    logdir = '../logs'

    # device setting cpu or cuda(gpu)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print('=====Setting=====')
    print('Epochs: ', epochs)
    print('Batch Size: ', batch_size)
    print('Validation Rate: ', valid_rate)
    print('Learning Rate: ', lr)
    print('Target: ', target)
    print('Monitor: ', monitor)
    print('Model Name: ', model_name)
    print('Mode: ', mode)
    print('Save Directory: ', savedir)
    print('Log Directory: ', logdir)
    print('Device: ', device)
    print('Verbose: ', verbose)
    print()
    print('Setting Random Seed')
    print()
    seed_everything()  # seed setting

    # Data Load
    print('=====Data Load=====')
    if target == 'mnist':
        trainloader, validloader, testloader = mnist_load(
            batch_size=batch_size, validation_rate=valid_rate, shuffle=True)

    elif target == 'cifar10':
        trainloader, validloader, testloader = cifar10_load(
            batch_size=batch_size, validation_rate=valid_rate, shuffle=True)

    else:
        raise ValueError("target must be 'mnist' or 'cifar10', got {!r}".format(target))

    # ROAR or KAR
    if (args.eval == 'ROAR') or (args.eval == 'KAR'):
        # saliency map load
        hf = h5py.File(
            f'../saliency_maps/[{args.target}]{args.method}_train.hdf5', 'r')
        sal_maps = np.array(hf['saliencys'])
        # adjust image
        data_lst = adjust_image(kwargs['ratio'], trainloader, sal_maps,
                                args.eval)
        # hdf5 close
        hf.close()
        # model name
        model_name = model_name + '_{0:}_{1:}{2:.1f}'.format(
            args.method, args.eval, kwargs['ratio'])

    print('=====Model Load=====')
    # Load model
    net = SimpleCNN(target).to(device)
    print()

    # Model compile
    optimizer = optim.SGD(net.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=0.0005)
    criterion = nn.CrossEntropyLoss()

    # Train
    modeltrain = ModelTrain(model=net,
                            data=trainloader,
                            epochs=epochs,
                            criterion=criterion,
                            optimizer=optimizer,
                            device=device,
                            model_name=model_name,
                            savedir=savedir,
                            monitor=monitor,
                            mode=mode,
                            validation=validloader,
                            verbose=verbose)
    # Test
    modeltest = ModelTest(model=net,
                          data=testloader,
                          loaddir=savedir,
                          model_name=model_name,
                          device=device)

    modeltrain.history['test_result'] = modeltest.results

    # History save as json file
    if not (os.path.isdir(logdir)):
        os.mkdir(logdir)
    with open(f'{logdir}/{model_name}_logs.txt', 'w') as outfile:
        json.dump(modeltrain.history, outfile)
Example #16
            object_store_memory=args.object_store_memory,
            additional_archive="MNIST_data.zip#MNIST_data")
        ray_ctx = OrcaContext.get_ray_context()
    elif cluster_mode == "local":
        sc = init_orca_context(cores=args.driver_cores)
        ray_ctx = OrcaContext.get_ray_context()
    elif cluster_mode == "spark-submit":
        sc = init_orca_context(cluster_mode=cluster_mode)
        ray_ctx = OrcaContext.get_ray_context()
    else:
        print(
            "init_orca_context failed. cluster_mode should be one of 'local', 'yarn' and 'spark-submit' but got "
            + cluster_mode)

    # Create a parameter server with some random weights.
    net = SimpleCNN()
    all_keys, all_values = net.get_weights()
    ps = ParameterServer.remote(all_keys, all_values)

    # Start some training tasks.
    worker_tasks = [worker_task.remote(ps, i) for i in range(args.num_workers)]

    # Download MNIST.
    mnist = download_mnist_retry()
    print("Begin iteration")
    i = 0
    while i < args.iterations:
        # Get and evaluate the current model.
        print("-----Iteration" + str(i) + "------")
        current_weights = ray.get(ps.pull.remote(all_keys))
        net.set_weights(all_keys, current_weights)
Example #17
    # TODO: Tensorboard Check

    # python main.py --train --target=['mnist','cifar10'] --attention=['CAM','CBAM','RAN','WARN']
    if args.train:
        main(args=args)

    elif args.eval == 'selectivity':
        # make evaluation directory
        if not os.path.isdir('../evaluation'):
            os.mkdir('../evaluation')

        # pretrained model load
        weights = torch.load('../checkpoint/simple_cnn_{}.pth'.format(
            args.target))
        model = SimpleCNN(args.target)
        model.load_state_dict(weights['model'])

        # selectivity evaluation
        selectivity_method = Selectivity(model=model,
                                         target=args.target,
                                         batch_size=args.batch_size,
                                         method=args.method,
                                         sample_pct=args.ratio)
        # evaluation
        selectivity_method.eval(steps=args.steps, savedir='../evaluation')

    elif (args.eval == 'ROAR') or (args.eval == 'KAR'):
        # ratio
        ratio_lst = np.arange(0, 1, args.ratio)[1:]  # exclude zero
        for ratio in ratio_lst:
Example #18
def main(args, **kwargs):
    #################################
    # Config
    #################################
    epochs = args.epochs
    batch_size = args.batch_size
    valid_rate = args.valid_rate
    lr = args.lr
    verbose = args.verbose

    # checkpoint
    target = args.target
    attention = args.attention
    monitor = args.monitor
    mode = args.mode

    # save name
    model_name = 'simple_cnn_{}'.format(target)
    if attention in ['CAM', 'CBAM']:
        model_name = model_name + '_{}'.format(attention)
    elif attention in ['RAN', 'WARN']:
        model_name = '{}_{}'.format(target, attention)

    # save directory
    savedir = '../checkpoint'
    logdir = '../logs'

    # device setting cpu or cuda(gpu)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print('=====Setting=====')
    print('Training: ', args.train)
    print('Epochs: ', epochs)
    print('Batch Size: ', batch_size)
    print('Validation Rate: ', valid_rate)
    print('Learning Rate: ', lr)
    print('Target: ', target)
    print('Monitor: ', monitor)
    print('Model Name: ', model_name)
    print('Mode: ', mode)
    print('Attention: ', attention)
    print('Save Directory: ', savedir)
    print('Log Directory: ', logdir)
    print('Device: ', device)
    print('Verbose: ', verbose)
    print()
    print('Evaluation: ', args.eval)
    if args.eval is not None:
        print('Pixel ratio: ', kwargs['ratio'])
    print()
    print('Setting Random Seed')
    print()
    seed_everything()  # seed setting

    #################################
    # Data Load
    #################################
    print('=====Data Load=====')
    if target == 'mnist':
        trainloader, validloader, testloader = mnist_load(
            batch_size=batch_size, validation_rate=valid_rate, shuffle=True)

    elif target == 'cifar10':
        trainloader, validloader, testloader = cifar10_load(
            batch_size=batch_size, validation_rate=valid_rate, shuffle=True)

    else:
        raise ValueError("target must be 'mnist' or 'cifar10', got {!r}".format(target))

    #################################
    # ROAR or KAR
    #################################
    if (args.eval == 'ROAR') or (args.eval == 'KAR'):
        # saliency map load
        filename = f'../saliency_maps/[{args.target}]{args.method}'
        if attention in ['CBAM', 'RAN']:
            filename += f'_{attention}'
        hf = h5py.File(f'{filename}_train.hdf5', 'r')
        sal_maps = np.array(hf['saliencys'])
        # adjust image
        data_lst = adjust_image(kwargs['ratio'], trainloader, sal_maps,
                                args.eval)
        # hdf5 close
        hf.close()
        # model name
        model_name = model_name + '_{0:}_{1:}{2:.1f}'.format(
            args.method, args.eval, kwargs['ratio'])

    # skip this run if its log file already exists
    if os.path.isfile('{}/{}_logs.txt'.format(logdir, model_name)):
        sys.exit()

    #################################
    # Load model
    #################################
    print('=====Model Load=====')
    if attention == 'RAN':
        net = RAN(target).to(device)
    elif attention == 'WARN':
        net = WideResNetAttention(target).to(device)
    else:
        net = SimpleCNN(target, attention).to(device)
    n_parameters = sum([np.prod(p.size()) for p in net.parameters()])
    print('Total number of parameters:', n_parameters)
    print()

    # Model compile
    optimizer = optim.SGD(net.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=0.0005)
    criterion = nn.CrossEntropyLoss()

    #################################
    # Train
    #################################
    modeltrain = ModelTrain(model=net,
                            data=trainloader,
                            epochs=epochs,
                            criterion=criterion,
                            optimizer=optimizer,
                            device=device,
                            model_name=model_name,
                            savedir=savedir,
                            monitor=monitor,
                            mode=mode,
                            validation=validloader,
                            verbose=verbose)

    #################################
    # Test
    #################################
    modeltest = ModelTest(model=net,
                          data=testloader,
                          loaddir=savedir,
                          model_name=model_name,
                          device=device)

    modeltrain.history['test_result'] = modeltest.results

    # History save as json file
    if not (os.path.isdir(logdir)):
        os.mkdir(logdir)
    with open(f'{logdir}/{model_name}_logs.txt', 'w') as outfile:
        json.dump(modeltrain.history, outfile)
Example #19
            extra_python_lib=args.extra_python_lib,
            additional_archive="MNIST_data.zip#MNIST_data")
        ray_ctx = OrcaContext.get_ray_context()
    elif cluster_mode == "local":
        sc = init_orca_context(cores=args.driver_cores)
        ray_ctx = OrcaContext.get_ray_context()
    elif cluster_mode == "spark-submit":
        sc = init_orca_context(cluster_mode=cluster_mode)
        ray_ctx = OrcaContext.get_ray_context()
    else:
        print(
            "init_orca_context failed. cluster_mode should be one of 'local', 'yarn' and 'spark-submit' but got "
            + cluster_mode)

    # Create a parameter server.
    net = SimpleCNN()
    ps = ParameterServer.remote(1e-4 * args.num_workers)

    # Create workers.
    workers = [
        Worker.remote(worker_index) for worker_index in range(args.num_workers)
    ]

    # Download MNIST.
    mnist = download_mnist_retry()

    i = 0
    current_weights = ps.get_weights.remote()
    print("Begin iteration")
    while i < args.iterations:
        # Compute and apply gradients.