Example #1
    def __init__(self,
                 directory,
                 epochs=1,
                 cuda=False,
                 save=False,
                 log_interval=30,
                 load=None,
                 split=(0.6, 0.2, 0.2),
                 cache=False,
                 minibatch_size=10,
                 pretrained=False):
        self.dataset = Dataset(directory,
                               split=split,
                               cache=cache,
                               minibatch_size=minibatch_size)
        self.epochs = epochs
        self.cuda = cuda
        self.save = save
        self.log_interval = log_interval
        self.model = DenseNet(pretrained)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.01)
        if load is not None:
            state = torch.load(load)
            self.model.load_state_dict(state['model'])
            self.optimizer.load_state_dict(state['optim'])
        if cuda:
            self.model = self.model.cuda()
Example #2
    def build_model(self):
        """
        Instantiates the model, loss criterion, and optimizer
        """

        # instantiate model
        self.model = DenseNet(config=self.config,
                              channels=self.input_channels,
                              class_count=self.class_count,
                              num_features=self.num_features,
                              compress_factor=self.compress_factor,
                              expand_factor=self.expand_factor,
                              growth_rate=self.growth_rate)

        # instantiate loss criterion
        self.criterion = nn.CrossEntropyLoss()

        # instantiate optimizer
        self.optimizer = optim.SGD(params=self.model.parameters(),
                                   lr=self.lr,
                                   momentum=self.momentum,
                                   weight_decay=self.weight_decay,
                                   nesterov=True)

        # print network
        self.print_network(self.model, 'DenseNet')

        # use gpu if enabled
        if torch.cuda.is_available() and self.use_gpu:
            self.model.cuda()
            self.criterion.cuda()
Example #3
File: main.py, Project: vmelan/DenseNet-tf2
def main(config):
    # Load CIFAR data
    data = DataLoader(config)
    train_loader, test_loader = data.prepare_data()

    model = DenseNet(config)

    model.build((config["trainer"]["batch_size"], 224, 224, 3))
    model.summary()

    optimizer = tf.keras.optimizers.Adam(lr=0.001)
    loss_object = tf.keras.losses.CategoricalCrossentropy()
    train_loss = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
    train_accuracy = tf.keras.metrics.CategoricalAccuracy(
        name='train_accuracy')

    def train_step(images, labels):
        with tf.GradientTape() as tape:
            predictions = model(images)
            loss = loss_object(labels, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        train_loss(loss)
        train_accuracy(labels, predictions)

    for epoch in range(config["trainer"]["epochs"]):
        for step, (images, labels) in tqdm(
                enumerate(train_loader),
                total=int(len(data) / config["trainer"]["batch_size"])):
            train_step(images, labels)
        template = 'Epoch {}, Loss: {:.4f}, Accuracy: {:.4f}'
        print(
            template.format(epoch + 1, train_loss.result(),
                            train_accuracy.result() * 100))
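        # A hedged addition, not in the original snippet: tf.keras metrics keep
        # accumulating state until reset_states() is called, so clearing them
        # here keeps each epoch's reported loss/accuracy independent.
        train_loss.reset_states()
        train_accuracy.reset_states()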
Example #4
def main():
    args = get_args()
    logger = get_logger(args.logdir)

    logger.info(vars(args))

    with tf.device("/gpu:0"):
        if args.densenet:
            model = DenseNet()
        else:
            model = CnnLstm()

    tfrecord = TFRecord()
    tfrecord.make_iterator(args.tfr_fname)

    total = sum(1 for _ in tf.python_io.tf_record_iterator(args.tfr_fname)) // tfrecord.batch_size

    saver = tf.train.Saver(tf.global_variables())

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        writer = tf.summary.FileWriter(args.logdir, sess.graph)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        for epoch in range(args.num_epochs):
            logger.info(f"Epoch {epoch+1}")

            sess.run(tfrecord.init_op)
            loss_list = []

            progress_bar = tqdm(total=total, desc="[TRAIN] Loss: 0", unit="batch", leave=False)

            while True:
                try:
                    step = sess.run(model.global_step)

                    spec, label = tfrecord.load(sess, training=True)

                    _, loss, merged = model.train(sess, spec, label)

                    progress_bar.update(1)
                    progress_bar.set_description(f"[TRAIN] Batch Loss: {loss:.4f}")

                    loss_list.append(loss)

                    writer.add_summary(summary=merged, global_step=step)

                except tf.errors.OutOfRangeError:
                    break

            progress_bar.close()

            mean_loss = np.mean(loss_list)
            logger.info(f"  -  [TRAIN] Mean Loss: {mean_loss:.4f}")

            saver.save(sess, args.logdir + "/" + args.save_fname + ".ckpt", global_step=sess.run(model.global_step))
Example #5
File: modellist.py, Project: AhnYoungBin/ex
    def __call__(self, x, seon):
        return {
            1: Vgg.VGG11(seon),
            2: Vgg.VGG13(seon),
            3: Vgg.VGG16(seon),
            4: Vgg.VGG19(seon),
            5: ResNet.ResNet18(seon),
            6: ResNet.ResNet34(seon),
            7: ResNet.ResNet50(seon),
            8: ResNet.ResNet101(seon),
            9: ResNet.ResNet152(seon),
            10: DenseNet.DenseNet121(seon),
            11: DenseNet.DenseNet169(seon),
            12: DenseNet.DenseNet201(seon),
            13: DenseNet.DenseNet161(seon)
        }[x]
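# A hedged note on the dispatch above (not from the original example): the dict
# literal constructs every entry before indexing, so all thirteen models are
# instantiated on each call. A lazy variant would map x to a class and only then
# call it, e.g.:
# model = {10: DenseNet.DenseNet121, 11: DenseNet.DenseNet169}[x](seon)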
Example #6
def model_fn(features,
             labels,
             mode,
             params):

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    all_config = params["config"]

    logits = DenseNet(features, all_config.GROWTH_RATE,
                                all_config.DEPTH,
                                all_config.NUM_DENSE_BLOCK,
                                all_config.NUM_INIT_FILTER,
                                all_config.SUB_SAMPLE_IMAGE,
                                all_config.NUM_CLASSES,
                                training=is_training,
                                bottleneck=all_config.BOTTLENECK,
                                dropout_rate=all_config.DROPOUT_RATES,
                                compression=all_config.COMPRESSION,
                                data_format=all_config.DATA_FORMAT,
                                all_config=all_config)

    with tf.variable_scope("loss"):
        classifier_loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)

        regularization_list = [tf.reduce_sum(all_config.WEIGHT_DECAY * tf.square(w.read_value()))
                               for w in tf.trainable_variables()]
        regularization_loss = tf.add_n(regularization_list)

        total_loss = classifier_loss + regularization_loss
    global_step = tf.train.get_or_create_global_step()
    lr = tf.train.piecewise_constant(global_step,
                                     boundaries=[np.int64(all_config.BOUNDARY[0]), np.int64(all_config.BOUNDARY[1])],
                                     values=[all_config.INIT_LEARNING_RATE, all_config.INIT_LEARNING_RATE / 10,
                                             all_config.INIT_LEARNING_RATE / 100])
    tf.summary.scalar('learning_rate', lr)
    optimizer = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies([tf.group(*update_ops)]):
        train_op = optimizer.minimize(total_loss, global_step)

    predictions = tf.math.argmax(tf.nn.softmax(logits, axis=-1), axis=-1)
    accuracies, update_accuracies = tf.metrics.accuracy(labels, predictions)

    meta_hook = MetadataHook(save_steps=all_config.SAVE_EVERY_N_STEP*all_config.EPOCH/2, output_dir=all_config.LOG_OUTPUT_DIR)
    summary_hook = tf.train.SummarySaverHook(save_steps=all_config.SAVE_EVERY_N_STEP,
                                             output_dir=os.path.join(all_config.LOG_OUTPUT_DIR, all_config.NET_NAME),
                                             summary_op=tf.summary.merge_all())

    if mode == tf.estimator.ModeKeys.TRAIN:
        return tf.estimator.EstimatorSpec(mode, loss=total_loss,
                                          train_op=train_op,
                                          training_hooks=[meta_hook, summary_hook])

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode, loss=total_loss,
                                          eval_metric_ops={'accuracies': (accuracies, update_accuracies)})
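# A hedged sketch (not from the original example) of plugging the model_fn above
# into the tf.estimator API; `all_config` and `train_input_fn` are placeholder
# assumptions standing in for the project's real config object and input pipeline.
estimator = tf.estimator.Estimator(model_fn=model_fn,
                                   model_dir=all_config.LOG_OUTPUT_DIR,
                                   params={"config": all_config})
estimator.train(input_fn=train_input_fn)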
Example #7
    def initialize_Q(self, model_path=None, alpha=None, **kwargs):
        lr = 10**(-2)
        if alpha:
            lr = alpha

        # Input/Output size for the network
        output_dim = self.num_actions

        self.model = DenseNet(output_dim, **kwargs)
        if model_path is not None:
            print('loading check point %s' % model_path)
            self.model.load_state_dict(torch.load(model_path))
        self.Q = self.model
        if self.use_target:
            self.target_network = self.Q.copy()
        else:
            self.target_network = self.Q

        self.use_cuda = False
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.use_cuda = True
            self.model.to(torch.device('cuda:0'))
            self.device = torch.device('cuda:0')

        self.target_network.to(self.device)
        self.model.to(self.device)

        self.loss = nn.SmoothL1Loss()
        self.max_lr = lr
        self.optimizer = optim.RMSprop(self.model.parameters(), lr)
        if self.schedule:
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, patience=10, verbose=True)
        for p in self.model.parameters():
            p.register_hook(lambda grad: torch.clamp(grad, -1, 1))
Example #8
def main():
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Pad(4),
        torchvision.transforms.RandomCrop((32, 32)),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
    ])
    train_dataset = torchvision.datasets.CIFAR10(
        '/home/qx/project/data/cifar10-data', train=True, transform=transform)
    test_dataset = torchvision.datasets.CIFAR10(
        '/home/qx/project/data/cifar10-data',
        train=False,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
        ]))
    dataloaders_dict = {
        'train':
        torch.utils.data.DataLoader(train_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=4),
        'val':
        torch.utils.data.DataLoader(test_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=4)
    }

    all_net = {
        #'DenseNet29': DenseNet.densenet29(),
        #'DenseNet45': DenseNet.densenet45(),
        'DenseNet85': DenseNet.densenet85()
    }
    lr = 0.1
    for name, net in all_net.items():
        net = net.to(device)
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=lr,
                                    momentum=0.9,
                                    weight_decay=weight_decay)
        criterion = torch.nn.CrossEntropyLoss()
        net, val_hist, train_hist, best_acc = train_model(
            net, dataloaders_dict, criterion, optimizer, num_epochs)
        # torch.save(net.state_dict(), save_path+name+'.pth')
        val_hist = [h for h in val_hist]
        train_hist = [h for h in train_hist]
        savehis(val_hist, train_hist, save_path + name + '.png', name,
                best_acc)
Example #9
def main():
    args = get_args()
    logger = get_logger()

    logger.info(vars(args))

    with tf.device("/gpu:0"):
        if args.densenet:
            model = DenseNet()
        else:
            model = CnnLstm()

    tfrecord = TFRecord()
    tfrecord.make_iterator(args.tfr_fname, training=False)

    total = sum(1 for _ in tf.python_io.tf_record_iterator(args.tfr_fname)) // tfrecord.batch_size

    saver = tf.train.Saver(tf.global_variables())

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        saver.restore(sess, args.model_fname)

        sess.run(tfrecord.init_op)

        spec = tfrecord.load(sess, training=False)
        predict = model.predict(sess, spec, args.proba)

        progress_bar = tqdm(total=total, desc="[PREDICT]", unit="batch", leave=False)

        while True:
            try:
                spec = tfrecord.load(sess, training=False)

                if args.proba:
                    predict = np.vstack([predict, model.predict(sess, spec, args.proba)])
                else:
                    predict = np.hstack([predict, model.predict(sess, spec, args.proba)])

                progress_bar.update(1)

            except tf.errors.OutOfRangeError:
                break

    make_submission(predict, args.sample_fname, args.output_fname, args.proba)
    logger.info(f"{args.output_fname} is created.")
Example #10
from model import Net, DenseNet, STN_DenseNet

############################# Models #########################################
# If you would like to test DenseNet with STN module, simply replace:
# model = DenseNet(...) with model = STN_DenseNet(...)
#############################################################################
# The best model I could get; it achieved ~99.13% on the public leaderboard
# Please refer to aug_large.pth for the pre-trained model
# Even without **offline** preprocessing+augmentation, this model could achieve ~98.6% (my first submission :D)
# Please refer to noaug_large.pth
# This model has about 10M parameters and consumes about 15GB of GPU memory during
# training, which fits on a single TITAN V
model = DenseNet(
    growth_rate=24,  # K
    block_config=(32, 32, 32),  # (L - 4)/6
    num_init_features=48,  # 2 * growth rate
    bn_size=4,
    drop_rate=0.1,
    num_classes=43)

# A smaller model, which could also achieve ~98.3% on the public leaderboard
# with offline preprocessing+augmentation
# Please refer to aug_small.pth
# This model has about 0.9M parameters
# model = DenseNet(growth_rate = 12, # K
#                  block_config = (16,16,16), # (L - 4)/6
#                  num_init_features = 32,
#                  bn_size = 4,
#                  drop_rate = 0.1,
#                  num_classes = 43)
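# A hedged sketch of the STN swap described in the comments above, assuming
# STN_DenseNet accepts the same keyword arguments as DenseNet (as the
# "simply replace" comment implies):
# model = STN_DenseNet(growth_rate=24,  # K
#                      block_config=(32, 32, 32),  # (L - 4)/6
#                      num_init_features=48,  # 2 * growth rate
#                      bn_size=4,
#                      drop_rate=0.1,
#                      num_classes=43)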
Example #11
class TetrisQLearn:
    def __init__(self, games_state, savename, dirname='logs', **kwargs):

        self.simulator = games_state

        # Q learn basic params
        self.explore_val = 1  # probability to explore v exploit
        self.explore_decay = 0.999  # explore chance is reduced as Q resolves
        self.gamma = 1  # short-term/long-term trade-off param
        self.num_episodes = 500  # number of episodes of simulation to perform
        self.save_weight_freq = 10  # controls how often (in number of episodes) the weights of Q are saved
        self.memory = []
        self._process_mask = []
        self.processed_memory = []  # memory container

        # fitted Q-Learning params
        self.episode_update = 1  # after how many episodes should we update Q?
        self.batch_size = 10  # length of memory replay (in episodes)

        self.schedule = False
        self.refresh_target = 1

        self.renderpath = None

        # let user define each of the params above
        if "gamma" in kwargs:
            self.gamma = kwargs['gamma']
        if 'explore_val' in kwargs:
            self.explore_val = kwargs['explore_val']
        if 'explore_decay' in kwargs:
            self.explore_decay = kwargs['explore_decay']
        if 'num_episodes' in kwargs:
            self.num_episodes = kwargs['num_episodes']
        if 'episode_update' in kwargs:
            self.episode_update = kwargs['episode_update']
        if 'exit_level' in kwargs:
            self.exit_level = kwargs['exit_level']
        if 'exit_window' in kwargs:
            self.exit_window = kwargs['exit_window']
        if 'save_weight_freq' in kwargs:
            self.save_weight_freq = kwargs['save_weight_freq']
        if 'batch_size' in kwargs:
            self.batch_size = kwargs['batch_size']
        if 'schedule' in kwargs:
            self.schedule = kwargs['schedule']
        if 'memory_length' in kwargs:
            self.memory_length = kwargs['memory_length']
        if 'refresh_target' in kwargs:
            self.refresh_target = kwargs['refresh_target']
        if 'minibatch_size' in kwargs:
            self.minibatch_size = kwargs['minibatch_size']
        if 'render_path' in kwargs:
            self.renderpath = kwargs['render_path']
        if 'use_target' in kwargs:
            self.use_target = kwargs['use_target']

        # get simulation-specific variables from simulator
        self.num_actions = self.simulator.output_dimension
        self.training_reward = []
        self.savename = savename

        # create text file for training log
        self.logname = os.path.join(dirname, 'training_logs',
                                    savename + '.txt')
        self.reward_logname = os.path.join(dirname, 'reward_logs',
                                           savename + '.txt')
        if not os.path.exists(
                os.path.join(dirname, 'saved_model_weights', savename)):
            os.mkdir(os.path.join(dirname, 'saved_model_weights', savename))
        if self.renderpath and not os.path.exists(
                os.path.join(self.renderpath, self.savename)):
            os.makedirs(os.path.join(self.renderpath, self.savename),
                        exist_ok=True)
        self.renderpath = os.path.join(self.renderpath, self.savename)

        self.weights_folder = os.path.join(dirname, 'saved_model_weights',
                                           savename)
        self.reward_table = os.path.join(dirname, 'reward_logs_extended',
                                         savename + '.csv')
        self.weights_idx = 0

        self.init_log(self.logname)
        self.init_log(self.reward_logname)
        self.init_log(self.reward_table)

        self.write_header = True

    def render_model(self, epoch):
        fig = plt.figure()
        ax = fig.gca(projection='3d')

        print('rendering...')
        tick = time.time()
        axis_extents = self.simulator.board_extents
        ax.set_xlim3d(0, axis_extents[0])
        ax.set_ylim3d(0, axis_extents[1])
        ax.set_zlim3d(0, axis_extents[2])
        demo_game = GameState(board_shape=axis_extents,
                              rewards=self.simulator.rewards)
        self.model.eval()
        path = os.path.join(self.renderpath,
                            self.savename + '-' + str(epoch) + '.gif')
        render.render_from_model(self.model,
                                 fig,
                                 ax,
                                 demo_game,
                                 path,
                                 device=self.device)
        self.model.train()
        print('rendered %s in %.2fs' % (path, time.time() - tick))

    # Logging stuff
    def init_log(self, logname):
        # delete log if old version exists
        if os.path.exists(logname):
            os.remove(logname)

    def update_log(self, logname, update, epoch=None):
        if type(update) == str:
            logfile = open(logname, "a")
            logfile.write(update)
            logfile.close()
        else:
            mod = self.model.state_dict()
            torch.save(
                mod,
                os.path.join(self.weights_folder,
                             self.savename + str(epoch) + '.pth'))
            self.weights_idx += 1

    def log_reward(self, reward_dict):

        keys = reward_dict.keys()

        if self.write_header:
            with open(self.reward_table, 'a') as output_file:
                dict_writer = csv.DictWriter(output_file, keys)
                dict_writer.writeheader()
            self.write_header = False

        with open(self.reward_table, 'a') as output_file:
            dict_writer = csv.DictWriter(output_file, keys)
            dict_writer.writerow(reward_dict)

    # Q Learning Stuff
    def initialize_Q(self, model_path=None, alpha=None, **kwargs):
        lr = 10**(-2)
        if alpha:
            lr = alpha

        # Input/Output size for the network
        output_dim = self.num_actions

        self.model = DenseNet(output_dim, **kwargs)
        if model_path is not None:
            print('loading check point %s' % model_path)
            self.model.load_state_dict(torch.load(model_path))
        self.Q = self.model
        if self.use_target:
            self.target_network = self.Q.copy()
        else:
            self.target_network = self.Q

        self.use_cuda = False
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.use_cuda = True
            self.model.to(torch.device('cuda:0'))
            self.device = torch.device('cuda:0')

        self.target_network.to(self.device)
        self.model.to(self.device)

        self.loss = nn.SmoothL1Loss()
        self.max_lr = lr
        self.optimizer = optim.RMSprop(self.model.parameters(), lr)
        if self.schedule:
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, patience=10, verbose=True)
        for p in self.model.parameters():
            p.register_hook(lambda grad: torch.clamp(grad, -1, 1))

    def _embed_piece(self, piece, embedding_size=(4, 4, 4)):
        out = np.zeros((piece.shape[0], piece.shape[1], *embedding_size))
        for coord in product(*[range(x) for x in piece.shape]):
            out[coord] = piece[coord]
        return torch.from_numpy(out).to(self.device)

    def memory_replay(self):
        # process transitions using target network
        q_vals = []
        states = []
        pieces = []
        locats = []
        episode_loss = 0

        total_processed = 0

        tick = time.time()
        for i in range(len(self.memory)):
            episode_data = self.memory[i]
            if self.processed_memory[i] is None:
                self.processed_memory[i] = [None] * len(episode_data)

            for j in range(len(episode_data)):
                # process the sample and put it into the processed_memory
                sample = episode_data[j]

                state, piece, locat = sample[0]
                next_state, next_piece, next_locat = sample[1]
                action = sample[2]
                reward = sample[3]
                done = sample[4]

                if self.processed_memory[i][j] is None:
                    q = reward

                    # preprocess q using target network
                    if not done:
                        next_state = torch.tensor(next_state).to(self.device)
                        next_piece = torch.tensor(next_piece).to(self.device)
                        next_locat = torch.tensor(next_locat).to(self.device)

                        qs = self.target_network(next_state, next_piece,
                                                 next_locat)  #should be target
                        q += self.gamma * torch.max(qs)

                    state = torch.tensor(state).to(self.device)
                    piece = torch.tensor(piece).to(self.device)
                    locat = torch.tensor(locat).to(self.device)

                    # q is our experientially validated move score. Anchor it on our prediction vector
                    q_update = self.target_network(
                        state, piece, locat).squeeze(0)  # should be target
                    q_update[action] = q
                    processed = q_update.detach().cpu().numpy()
                    q_vals = q_vals + [processed]
                    self.processed_memory[i][j] = processed
                    total_processed += 1

                    ## WE HAVE NOW PREPROCESSED W TARGET NETWORK

                    # clear up the vram
                    state = state.cpu().squeeze(0).numpy()
                    piece = piece.float().squeeze(0).cpu().numpy()
                    locat = locat.float().cpu().numpy()

                else:
                    q_vals = q_vals + [self.processed_memory[i][j]]

                # it's goofy, but it works
                if state.ndim > 4:
                    state = state.squeeze(0)
                assert state.ndim == 4

                if piece.ndim > 4:
                    piece = piece.squeeze(0)
                assert piece.ndim == 4

                if locat.ndim > 2:
                    locat = locat.squeeze(0)
                assert locat.ndim == 2

                states.append(state)
                pieces.append(piece)
                locats.append(locat)

        elapsed_time = time.time() - tick
        print('process time: %.2f' % elapsed_time)
        print('total processed: %d (%.2f/s)' %
              (total_processed, total_processed / elapsed_time))

        # take descent step
        memory = MemoryDset(states, pieces, locats, q_vals)

        if self.minibatch_size > 0:
            ids = random.sample(range(len(memory)),
                                min(self.minibatch_size,
                                    len(memory) - 1))
            memory = torch.utils.data.Subset(memory, ids)
        dataloader = DataLoader(memory,
                                batch_size=self.batch_size,
                                shuffle=True)

        tick = time.time()
        for s, p, l, q in dataloader:
            self.optimizer.zero_grad()

            out = self.Q(s, p, l)
            loss = self.loss(out, q)
            loss.backward()
            self.optimizer.step()
            episode_loss += loss.item()
            s.detach_()
        print('fit time: %.2f' % (time.time() - tick))

        if self.schedule:
            self.scheduler.step(episode_loss)
        return episode_loss / len(self.memory) / len(dataloader)

    def update_target(self):
        if self.use_target:
            self.target_network = self.Q.copy()
            self.target_network.to(self.device)
        self.processed_memory = [None] * len(self.processed_memory)

    def update_memory(self, episode_data):
        # add most recent trial data to memory
        self.memory.append(episode_data)
        self.processed_memory.append(None)

        # clip memory if it gets too long
        num_episodes = len(self.memory)
        if num_episodes >= self.memory_length:
            num_delete = num_episodes - self.memory_length
            self.memory[:num_delete] = []
            self.processed_memory[:num_delete] = []

    def make_torch(self, array):
        tens = torch.from_numpy(array.copy())
        tens = tens.float()
        tens = tens.unsqueeze(0)
        tens = tens.unsqueeze(0)
        return tens.to(self.device)

    # choose next action
    def choose_action(self, state, piece, location):
        # pick action at random
        p = np.random.rand(1)

        action = np.random.randint(len(self.simulator.action_space))

        # pick action based on exploiting
        qs = self.Q(state, piece, location)

        if p > self.explore_val:
            action = torch.argmax(qs)
        return action

    def renormalize_vec(self, tensor, idx):
        tensor = tensor.squeeze(0)
        tensor[idx] = 0
        sum = torch.sum(tensor)
        return tensor / sum

    def run(self):

        print("num_episodes: %s" % self.num_episodes)

        # start main Q-learning loop
        for n in range(self.num_episodes):
            # pick this episode's starting position
            state = self.simulator.reset()
            total_episode_reward = 0
            done = False

            # get our exploit parameter for this episode
            if self.explore_val > 0.01 and (n % self.episode_update) == 0:
                old_explore = self.explore_val
                self.explore_val *= self.explore_decay
                if old_explore - self.explore_val > 0.25:
                    for param in self.optimizer.param_groups:
                        print('resetting to max learning rate: %s' %
                              self.max_lr)
                        param['lr'] = self.max_lr

            # run episode
            step = 0
            episode_data = []
            ep_start_time = time.time()
            ep_rew_dict = None

            while done is False:

                # choose next action
                board = self.make_torch(state.board)
                piece = self.make_torch(state.current.matrix)
                loc = torch.tensor(state.current.location).unsqueeze(0).to(
                    self.device)
                action = self.choose_action(board, piece, loc)

                # transition to next state, get associated reward
                next_state, reward_dict, done = self.simulator(
                    self.simulator.action_space[action])
                if ep_rew_dict is None:
                    ep_rew_dict = reward_dict
                else:
                    ep_rew_dict = add_reward_dicts(ep_rew_dict, reward_dict)
                next_board = self.make_torch(next_state.board)
                next_piece = self.make_torch(next_state.current.matrix)
                next_locat = torch.tensor(
                    next_state.current.location).unsqueeze(0)

                reward = reward_dict['total']

                # move board back to cpu to clear up vram
                board = board.cpu().numpy()
                next_board = next_board.cpu().numpy()
                piece = piece.cpu().numpy()
                location = loc.cpu().numpy()
                next_piece = next_piece.cpu().numpy()
                next_locat = next_locat.cpu().numpy()

                # store data for transition after episode ends
                episode_data.append([(board, piece, location),
                                     (next_board, next_piece, next_locat),
                                     action, reward, done])

                # update total reward from this episode
                total_episode_reward += reward
                state = copy.deepcopy(next_state)
                step += 1

            # update memory with this episode's data
            self.update_memory(episode_data)

            LOSS_SCALING = 100

            # update the target network
            if self.use_target:
                if np.mod(n, self.refresh_target) == 0:
                    self.update_target()
            else:
                self.update_target()

            # train model
            episode_loss = 0
            if np.mod(n, self.episode_update) == 0:
                episode_loss = self.memory_replay() * LOSS_SCALING

            # track the average reward over the most recent exit_window episodes
            exit_ave = total_episode_reward
            if n >= self.exit_window:
                exit_ave = np.sum(
                    np.array(self.training_reward[-self.exit_window:])
                ) / self.exit_window

            # print out updates
            # I abuse the f**k out of this variable. Watch how many different values it assumes and how
            # important the order of operations is. I do this because I hate myself.
            update = 'episode ' + str(n + 1) + ' of ' + str(
                self.num_episodes
            ) + ' complete, ' + 'loss x%s = ' % LOSS_SCALING + str(
                np.round(episode_loss, 3)) + ' explore val = ' + str(
                    np.round(self.explore_val, 3)
                ) + ', episode reward = ' + str(
                    np.round(
                        total_episode_reward, 1)) + ', ave reward = ' + str(
                            np.round(exit_ave, 3)) + ', episode_time = ' + str(
                                np.round(time.time() - ep_start_time, 3))

            self.update_log(self.logname, update + '\n')

            if np.mod(n, self.episode_update) == 0:
                print(colored(update, 'red'))
            else:
                print(update)

            # save latest weights from this episode
            if np.mod(n, self.save_weight_freq) == 0:
                update = self.model.state_dict()
                self.update_log(self.weights_folder, update, epoch=n)

            if self.renderpath and n % self.save_weight_freq == self.save_weight_freq - 1:
                self.render_model(n + 1)

            update = str(total_episode_reward) + '\n'
            self.update_log(self.reward_logname, update)
            self.log_reward(ep_rew_dict)

            # store this episode's computation time and training reward history
            self.training_reward.append(total_episode_reward)

        update = 'q-learning algorithm complete'
        self.update_log(self.logname, update + '\n')
        print(update)
Example #12
    
    if 'outputs' not in os.listdir(os.curdir):
        os.mkdir('outputs')

    checkpoint_filepath = 'outputs/' + 'model-{epoch:03d}.h5'

    if 'model-last_epoch.h5' in os.listdir('outputs/'):
        print('last model loaded')
        model = load_model('outputs/model-last_epoch.h5')

    else:
        print('created a new model instead')
        model = DenseNet(input_shape=(r, c, 1),
                         dense_blocks=5,
                         dense_layers=-1,
                         growth_rate=8,
                         dropout_rate=0.2,
                         bottleneck=True,
                         compression=1.0,
                         weight_decay=1e-4,
                         depth=40)

    # training parameters
    adamOpt = Adam(lr=0.0001)
    model.compile(loss='mean_squared_error',
                  optimizer=adamOpt,
                  metrics=['mae', 'mse'])
    model.summary(line_length=200)
    model.save("test")

#    log_filename = 'outputs/' + 'landmarks' + '_results.csv'
#    csv_log = callbacks.CSVLogger(log_filename, separator=',', append=True)
#    checkpoint = callbacks.ModelCheckpoint(checkpoint_filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min')
Example #13
def main():
    # define options
    parser = argparse.ArgumentParser(
        description='Training script of DenseNet on CIFAR-10 dataset')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=300,
                        help='Number of epochs to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--initmodel',
                        help='Initialize the model from given file')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Initialize the trainer from given file')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Output directory')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=64,
                        help='Minibatch size')
    parser.add_argument('--numlayers',
                        '-L',
                        type=int,
                        default=40,
                        help='Number of layers')
    parser.add_argument('--growth',
                        '-G',
                        type=int,
                        default=12,
                        help='Growth rate parameter')
    parser.add_argument('--dropout',
                        '-D',
                        type=float,
                        default=0.2,
                        help='Dropout ratio')
    parser.add_argument('--dataset',
                        type=str,
                        default='C10',
                        choices=('C10', 'C10+', 'C100', 'C100+'),
                        help='Dataset used for training (Default is C10)')
    args = parser.parse_args()

    # load dataset
    if args.dataset == 'C10':
        train, test = dataset.get_C10()
    elif args.dataset == 'C10+':
        train, test = dataset.get_C10_plus()
    elif args.dataset == 'C100':
        train, test = dataset.get_C100()
    elif args.dataset == 'C100+':
        train, test = dataset.get_C100_plus()

    train_iter = chainer.iterators.MultiprocessIterator(train, args.batchsize)
    test_iter = chainer.iterators.MultiprocessIterator(test,
                                                       args.batchsize,
                                                       repeat=False,
                                                       shuffle=False)

    # setup model
    model = L.Classifier(
        DenseNet(args.numlayers, args.growth, 16, args.dropout, 10))

    if args.initmodel:
        print('Load model from', args.initmodel)
        chainer.serializers.load_npz(args.initmodel, model)
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()
        model.to_gpu()

    # setup optimizer
    optimizer = chainer.optimizers.NesterovAG(lr=0.1, momentum=0.9)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(1e-4))

    # setup trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    trainer.extend(Evaluator(test_iter, model, device=args.gpu))
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot(), trigger=(10, 'epoch'))
    trainer.extend(
        extensions.snapshot_object(model, 'model_{.updater.epoch}.npz'))
    trainer.extend(extensions.LogReport())
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'main/loss', 'validation/main/loss', 'main/accuracy',
            'validation/main/accuracy'
        ]))
    trainer.extend(extensions.ProgressBar())

    # divide lr by 10 at 0.5 and 0.75 of the total number of training epochs
    iter_per_epoch = math.ceil(len(train) / args.batchsize)
    n_iter1 = int(args.epoch * 0.5 * iter_per_epoch)
    n_iter2 = int(args.epoch * 0.75 * iter_per_epoch)
    shifts = [(n_iter1, 0.01), (n_iter2, 0.001)]
    trainer.extend(StepShift('lr', shifts, optimizer))

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # start training
    trainer.run()
Example #14
class Solver(object):

    DEFAULTS = {}

    def __init__(self, version, data_loader, config):
        """
        Initializes a Solver object
        """

        # data loader
        self.__dict__.update(Solver.DEFAULTS, **config)
        self.version = version
        self.data_loader = data_loader

        self.build_model()

        # TODO: build tensorboard

        # start with a pre-trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def build_model(self):
        """
        Instantiates the model, loss criterion, and optimizer
        """

        # instantiate model
        self.model = DenseNet(config=self.config,
                              channels=self.input_channels,
                              class_count=self.class_count,
                              num_features=self.num_features,
                              compress_factor=self.compress_factor,
                              expand_factor=self.expand_factor,
                              growth_rate=self.growth_rate)

        # instantiate loss criterion
        self.criterion = nn.CrossEntropyLoss()

        # instantiate optimizer
        self.optimizer = optim.SGD(params=self.model.parameters(),
                                   lr=self.lr,
                                   momentum=self.momentum,
                                   weight_decay=self.weight_decay,
                                   nesterov=True)

        # print network
        self.print_network(self.model, 'DenseNet')

        # use gpu if enabled
        if torch.cuda.is_available() and self.use_gpu:
            self.model.cuda()
            self.criterion.cuda()

    def print_network(self, model, name):
        """
        Prints the structure of the network and the total number of parameters
        """
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def load_pretrained_model(self):
        """
        loads a pre-trained model from a .pth file
        """
        self.model.load_state_dict(torch.load(os.path.join(
            self.model_save_path, '{}.pth'.format(self.pretrained_model))))
        print('loaded trained model ver {}'.format(self.pretrained_model))

    def print_loss_log(self, start_time, iters_per_epoch, e, i, loss):
        """
        Prints the loss and elapsed time for each epoch
        """
        total_iter = self.num_epochs * iters_per_epoch
        cur_iter = e * iters_per_epoch + i

        elapsed = time.time() - start_time
        total_time = (total_iter - cur_iter) * elapsed / (cur_iter + 1)
        epoch_time = (iters_per_epoch - i) * elapsed / (cur_iter + 1)

        epoch_time = str(datetime.timedelta(seconds=epoch_time))
        total_time = str(datetime.timedelta(seconds=total_time))
        elapsed = str(datetime.timedelta(seconds=elapsed))

        log = "Elapsed {}/{} -- {}, Epoch [{}/{}], Iter [{}/{}], " \
              "loss: {:.4f}".format(elapsed,
                                    epoch_time,
                                    total_time,
                                    e + 1,
                                    self.num_epochs,
                                    i + 1,
                                    iters_per_epoch,
                                    loss)

        # TODO: add tensorboard

        print(log)

    def save_model(self, e):
        """
        Saves a model per e epoch
        """
        path = os.path.join(
            self.model_save_path,
            '{}/{}.pth'.format(self.version, e + 1)
        )

        torch.save(self.model.state_dict(), path)

    def model_step(self, images, labels):
        """
        A step for each iteration
        """

        # set model in training mode
        self.model.train()

        # empty the gradients of the model through the optimizer
        self.optimizer.zero_grad()

        # forward pass
        output = self.model(images)

        # compute loss
        loss = self.criterion(output, labels.squeeze())

        # compute gradients using back propagation
        loss.backward()

        # update parameters
        self.optimizer.step()

        # return loss
        return loss

    def train(self):
        """
        Training process
        """
        self.losses = []
        self.top_1_acc = []
        self.top_5_acc = []

        iters_per_epoch = len(self.data_loader)

        # start with a trained model if exists
        if self.pretrained_model:
            start = int(self.pretrained_model.split('/')[-1])
        else:
            start = 0

        # start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (images, labels) in enumerate(tqdm(self.data_loader)):
                images = to_var(images, self.use_gpu)
                labels = to_var(labels, self.use_gpu)

                loss = self.model_step(images, labels)

            # print out loss log
            if (e + 1) % self.loss_log_step == 0:
                self.print_loss_log(start_time, iters_per_epoch, e, i, loss)
                self.losses.append((e, loss))

            # save model
            if (e + 1) % self.model_save_step == 0:
                self.save_model(e)

            # evaluate on train dataset
            if (e + 1) % self.train_eval_step == 0:
                top_1_acc, top_5_acc = self.train_evaluate(e)
                self.top_1_acc.append((e, top_1_acc))
                self.top_5_acc.append((e, top_5_acc))

        # print losses
        print('\n--Losses--')
        for e, loss in self.losses:
            print(e, '{:.4f}'.format(loss))

        # print top_1_acc
        print('\n--Top 1 accuracy--')
        for e, acc in self.top_1_acc:
            print(e, '{:.4f}'.format(acc))

        # print top_5_acc
        print('\n--Top 5 accuracy--')
        for e, acc in self.top_5_acc:
            print(e, '{:.4f}'.format(acc))

    def eval(self, data_loader):
        """
        Returns the count of top 1 and top 5 predictions
        """

        # set the model to eval mode
        self.model.eval()

        top_1_correct = 0
        top_5_correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in data_loader:

                images = to_var(images, self.use_gpu)
                labels = to_var(labels, self.use_gpu)

                output = self.model(images)
                total += labels.size()[0]

                # top 1
                # get the max for each instance in the batch
                _, top_1_output = torch.max(output.data, dim=1)

                top_1_correct += torch.sum(torch.eq(labels.squeeze(),
                                                    top_1_output))

                # top 5
                _, top_5_output = torch.topk(output.data, k=5, dim=1)
                for i, label in enumerate(labels):
                    if label in top_5_output[i]:
                        top_5_correct += 1

        return top_1_correct.item(), top_5_correct, total

    def train_evaluate(self, e):
        """
        Evaluates the performance of the model using the train dataset
        """
        top_1_correct, top_5_correct, total = self.eval(self.data_loader)
        log = "Epoch [{}/{}]--top_1_acc: {:.4f}--top_5_acc: {:.4f}".format(
            e + 1,
            self.num_epochs,
            top_1_correct / total,
            top_5_correct / total
        )
        print(log)
        return top_1_correct / total, top_5_correct / total

    def test(self):
        """
        Evaluates the performance of the model using the test dataset
        """
        top_1_correct, top_5_correct, total = self.eval(self.data_loader)
        log = "top_1_acc: {:.4f}--top_5_acc: {:.4f}".format(
            top_1_correct / total,
            top_5_correct / total
        )
        print(log)
Example #15
from model import DenseNet

parser = argparse.ArgumentParser(description='PyTorch GTSRB evaluation script')
parser.add_argument('--data', type=str, default='data', metavar='D',
                    help="folder where data is located. train_data.zip and test_data.zip need to be found in the folder")
parser.add_argument('--model', type=str, metavar='M',
                    help="the model file to be evaluated. Usually it is of the form model_X.pth")
parser.add_argument('--outfile', type=str, default='gtsrb_kaggle.csv', metavar='D',
                    help="name of the output csv file")

args = parser.parse_args()

state_dict = torch.load(args.model)
model = DenseNet(growth_rate = 24, # K
                 block_config = (32,32,32), # (L - 4)/6
                 num_init_features = 48, # 2 * growth rate
                 bn_size = 4,
                 drop_rate = 0,
                 num_classes = 43)

model.load_state_dict(state_dict)
model = model.cuda()
model.eval()

from data import val_transforms

test_dir = args.data + '/test_sharp_images'

def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        with Image.open(f) as img:
Example #16
    if 'outputs' not in os.listdir(os.curdir):
        os.mkdir('outputs')

    checkpoint_filepath = 'outputs/' + 'model-{epoch:03d}.h5'

    if 'model-last_epoch.h5' in os.listdir('outputs/'):
        print('last model loaded')
        model = load_model('outputs/model-last_epoch.h5')

    else:
        print('created a new model instead')
        model = DenseNet(input_shape=(r, c, 1),
                         dense_blocks=5,
                         dense_layers=-1,
                         growth_rate=8,
                         dropout_rate=0.2,
                         bottleneck=True,
                         compression=1.0,
                         weight_decay=1e-4,
                         depth=40)

    # training parameters
    adamOpt = Adam(lr=0.0001)
    model.compile(loss='mean_squared_error',
                  optimizer=adamOpt,
                  metrics=['mae', 'mse'])
    model.summary(line_length=200)

    log_filename = 'outputs/' + 'landmarks' + '_results.csv'

    csv_log = callbacks.CSVLogger(log_filename, separator=',', append=True)
Example #17
    # Create the input data pipeline
    logging.info("Loading the datasets...")

    trainloader, testloader = data()
    train_dl = trainloader
    dev_dl = testloader
    logging.info("- done.")
    if "distill" in params.model_version:
        if params.model_version == "cnn_distill":
            model = CNN.CNN().to(device)
            optimizer = optim.Adam(model.parameters(), lr=params.learning_rate)
            loss_fn_kd = CNN.loss_fn_kd
            metrics = CNN.metrics
        elif params.model_version == 'densenet':
            model = DenseNet.DenseNet().to(device)
            optimizer = optim.SGD(model.parameters(),
                                  lr=params.learning_rate,
                                  momentum=0.9,
                                  weight_decay=1e-4)
            loss_fn_kd = CNN.loss_fn_kd
            metrics = DenseNet.metrics

        if params.teacher == "densenet":
            teacher_model = DenseNet.DenseNet()
            teacher_checkpoint = 'experiments/base_densenet/best.pth.tar'
            teacher_model = teacher_model.to(device)
        # elif params.teacher == "resnet18":
        #     teacher_model = resnet.ResNet18()
        #     teacher_checkpoint = 'experiments/base_resnet18/best.pth.tar'
        #     teacher_model = teacher_model.cuda() if params.cuda else teacher_model
Example #18
File: eval.py, Project: mthreet/drlnd-p1
from agent import Agent
from model import DenseNet

# Load the environment
env = UnityEnvironment(file_name='Banana_Linux/Banana.x86_64')
# Get the default brain
brain_name = env.brain_names[0]
brain = env.brains[brain_name]

# Get basic env parameters
env_info = env.reset(train_mode=False)[brain_name]
state_size = len(env_info.vector_observations[0])
action_size = brain.vector_action_space_size

# Load the Agent
net = DenseNet(state_size, action_size)
net.load_state_dict(torch.load('checkpoint.pth'))
agent = Agent(state_size, action_size, net)

env_info = env.reset(train_mode=False)[brain_name]  # reset the environment
state = env_info.vector_observations[0]  # get the current state
score = 0  # initialize the score
while True:
    action = agent.act(state)  # select an action
    env_info = env.step(action)[
        brain_name]  # send the action to the environment
    next_state = env_info.vector_observations[0]  # get the next state
    reward = env_info.rewards[0]  # get the reward
    done = env_info.local_done[0]  # see if episode has finished
    score += reward  # update the score
    state = next_state  # roll over the state to next time step
Example #19
def BackBone_Unet(backbone_name):
    up_parm_dict = {
        'resnet18': [512, 256, 128, 64, 64, 64, 64, 3],
        'resnet34': [512, 256, 128, 64, 64, 64, 64, 3],
        'resnet50': [2048, 1024, 512, 256, 128, 64, 64, 3],
        'resnet101': [2048, 1024, 512, 256, 128, 64, 64, 3],
        'resnet152': [2048, 1024, 512, 256, 128, 64, 64, 3],
        'densenet121': [1024, 1024, 512, 256, 128, 64, 64, 3],
        'densenet161': [2204, 2104, 752, 352, 128, 64, 64, 3],
        'densenet201': [1920, 1792, 512, 256, 128, 64, 64, 3],
        'densenet169': [1664, 1280, 512, 256, 128, 64, 64, 3],
        'efficientnet-b0': [1280, 112, 40, 24, 16, 16, 64, 3],
        'efficientnet-b1': [1280, 112, 40, 24, 16, 16, 64, 3],
        'efficientnet-b2': [1280, 120, 48, 24, 16, 16, 64, 3],
        'efficientnet-b3': [1280, 136, 48, 32, 24, 24, 64, 3],
        'efficientnet-b4': [1280, 160, 56, 32, 24, 24, 64, 3],
        'efficientnet-b5': [1280, 176, 64, 40, 24, 24, 64, 3],
        'efficientnet-b6': [1280, 200, 72, 40, 32, 32, 64, 3],
        'efficientnet-b7': [1280, 224, 80, 48, 32, 32, 64, 3]
    }

    efficient_param = {
        # 'efficientnet type': (width_coef, depth_coef, resolution, dropout_rate)
        'efficientnet-b0': (1.0, 1.0, 224, 0.2),
        'efficientnet-b1': (1.0, 1.1, 224, 0.2),
        'efficientnet-b2': (1.1, 1.2, 224, 0.3),
        'efficientnet-b3': (1.2, 1.4, 224, 0.3),
        'efficientnet-b4': (1.4, 1.8, 224, 0.4),
        'efficientnet-b5': (1.6, 2.2, 224, 0.4),
        'efficientnet-b6': (1.8, 2.6, 224, 0.5),
        'efficientnet-b7': (2.0, 3.1, 224, 0.5)
    }

    if backbone_name[0] == 'r':
        if backbone_name[-2:] == '18':
            model = ResNet.ResNet18()
        if backbone_name[-2:] == '34':
            model = ResNet.ResNet34()
        if backbone_name[-2:] == '50':
            model = ResNet.ResNet50()
        if backbone_name[-2:] == '01':
            model = ResNet.ResNet101()
        if backbone_name[-2:] == '52':
            model = ResNet.ResNet152()

        net = Res_Unet(model=model, up_parm=up_parm_dict[backbone_name])

    elif backbone_name[0] == 'd':
        if backbone_name[-2:] == '21':
            model = DenseNet.DenseNet121(seon=False)
        if backbone_name[-2:] == '61':
            model = DenseNet.DenseNet161(seon=False)
        if backbone_name[-2:] == '01':
            model = DenseNet.DenseNet201(seon=False)
        if backbone_name[-2:] == '69':
            model = DenseNet.DenseNet169(seon=False)

        net = Dense_Unet(model=model, up_parm=up_parm_dict[backbone_name])
    elif backbone_name[0] == 'e':
        param = efficient_param[backbone_name]
        model = EfficientNet.EfficientNet(param)
        net = Efficient_Unet(model=model, up_parm=up_parm_dict[backbone_name])

    return net
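# A hedged usage sketch (not from the original example): for an EfficientNet
# backbone, the tuple documented above, e.g. (1.0, 1.0, 224, 0.2) for
# 'efficientnet-b0', is looked up and passed straight to EfficientNet.EfficientNet.
# net = BackBone_Unet('efficientnet-b0')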
Example #20
def demo(save, depth=100, growth_rate=32, efficient=False, valid_size=420,
         n_epochs=50, batch_size=30, seed=None):
    """
    A demo to show off training of efficient DenseNets.
    Trains and evaluates a DenseNet-BC on an aluminum-profile dataset (custom dataset interface).

    Args:
        save (str) - path to save the model to

        depth (int) - depth of the network (number of convolution layers) (default 100)
        growth_rate (int) - number of features added per DenseNet layer (default 32)
        efficient (bool) - use the memory-efficient implementation? (default False)

        valid_size (int) - size of the validation set (default 420)
        n_epochs (int) - number of epochs for training (default 50)
        batch_size (int) - size of minibatch (default 30)
        seed (int) - manually set the random seed (default None)
    """

    # Get densenet configuration
    if (depth - 4) % 3:
        raise Exception('Invalid depth')
    block_config = [(depth - 4) // 6 for _ in range(3)]

    # Data transforms
    mean = [0.5071, 0.4867, 0.4408]
    stdv = [0.2675, 0.2565, 0.2761]
    train_transforms = {'train':transforms.Compose([
        transforms.RandomCrop(224, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=stdv)
    ])}


    # Datasets
    train_set = {'train':customData(txt_path=('train.txt'),
                           data_transforms=train_transforms,
                           dataset='train')}

    # Models
    model = DenseNet(
        growth_rate=growth_rate,
        block_config=block_config,
        num_classes=12,
        small_inputs=False,
        efficient=efficient 
        )
    print(model)

    # Make save directory
    if not os.path.exists(save):
        os.makedirs(save)
    if not os.path.isdir(save):
        raise Exception('%s is not a dir' % save)

    # Train the model
    train(model=model, train_set=train_set['train'], save=save,
          valid_size=valid_size, n_epochs=n_epochs, batch_size=batch_size, seed=seed)
    print('Done!')
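
For example, a hypothetical direct invocation of the demo above (the save path is a placeholder):

if __name__ == '__main__':
    demo(save='./checkpoints', depth=100, growth_rate=32, n_epochs=50, batch_size=30)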
Example #21
0
import numpy as np
from PIL import Image

# NOTE: the DenseNet import below assumes the Keras implementation lives in a
# local module named `model`, as in the other examples; adjust to your project.
from model import DenseNet

if __name__ == "__main__":

    ### prediction..

    image_path = 'try.jpg'
    im = Image.open(image_path)
    im = np.array(im)
    im = np.delete(im, [1, 2], axis=2)
    im = np.array(im) / 255.0
    im = np.expand_dims(im, axis=0)
    print(im.shape)
    model = DenseNet(dense_blocks=5,
                     dense_layers=-1,
                     growth_rate=8,
                     dropout_rate=0.2,
                     bottleneck=True,
                     compression=1.0,
                     weight_decay=1e-4,
                     depth=40)
    model.load_weights("outputs/model-230.h5")

    lmarks = model.predict(im)
    print(lmarks)
    lmarks = lmarks[0]
    lmarks[0:8:2] = lmarks[0:8:2] * im.shape[2]
    lmarks[1:8:2] = lmarks[1:8:2] * im.shape[1]
    print(lmarks)

    #print(lmarks)
    im = im[0] * 255
    im = np.squeeze(im, axis=(2, ))
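
The original example stops just before any visualisation. A hypothetical continuation that overlays the four predicted (x, y) landmark pairs on the grayscale image with matplotlib might look like this:

import matplotlib.pyplot as plt

plt.imshow(im, cmap='gray')
plt.scatter(lmarks[0:8:2], lmarks[1:8:2], c='r', s=20)  # x-coords, y-coords of the 4 landmarks
plt.show()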
Example #22
0
from model import DenseNet
from torchsummary import summary

model = DenseNet()
summary(model, (3, 224, 224))
Example #23
0
        else:

            adversarial_path = args.attack_path

    else:

        model = None

        if args.model == 'UNet':
            model = UNet(in_channels=n_channels, n_classes=n_classes)

        elif args.model == 'SegNet':
            model = SegNet(in_channels=n_channels, n_classes=n_classes)

        elif args.model == 'DenseNet':
            model = DenseNet(in_channels=n_channels, n_classes=n_classes)

        else:
            print("wrong model : must be UNet, SegNet, or DenseNet")
            raise SystemExit

        summary(model,
                input_size=(n_channels, args.height, args.width),
                device='cpu')

        model.load_state_dict(torch.load(args.model_path))

        adversarial_examples = DAG_Attack(model, test_dataset, args)

        if args.attack_path is None:
Example #24
0
class Trainer(object):
    """
    The Trainer class encapsulates all the logic necessary for
    training the DenseNet model. It uses SGD to update the weights
    of the model, subject to the hyperparameter constraints provided
    by the user in the config file.
    """
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Params
        ------
        - config: object containing command line arguments.
        - data_loader: data iterator
        """
        self.config = config
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
        else:
            self.test_loader = data_loader

        # network params
        self.num_blocks = config.num_blocks
        self.num_layers_total = config.num_layers_total
        self.growth_rate = config.growth_rate
        self.bottleneck = config.bottleneck
        self.theta = config.compression

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.best_valid_acc = 0.
        self.init_lr = config.init_lr
        self.lr = self.init_lr
        self.is_decay = True
        self.momentum = config.momentum
        self.weight_decay = config.weight_decay
        self.dropout_rate = config.dropout_rate
        if config.lr_sched == '':
            self.is_decay = False
        else:
            self.lr_decay = [float(x) for x in config.lr_sched.split(',')]

        # other params
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.num_gpu = config.num_gpu
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.dataset = config.dataset
        if self.dataset == 'cifar10':
            self.num_classes = 10
        elif self.dataset == 'cifar100':
            self.num_classes = 100
        else:
            self.num_classes = 1000

        # build densenet model
        self.model = DenseNet(self.num_blocks, self.num_layers_total,
            self.growth_rate, self.num_classes, self.bottleneck, 
                self.dropout_rate, self.theta)

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # define loss and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.init_lr,
                momentum=self.momentum, weight_decay=self.weight_decay)

        if self.num_gpu > 0:
            self.model.cuda()
            self.criterion.cuda()

        # finally configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.get_model_name()
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            configure(tensorboard_dir)

    def train(self):
        """
        Train the model on the training set. 

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        """
        # switch to train mode for dropout
        self.model.train()

        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        for epoch in trange(self.start_epoch, self.epochs):
            
            # decay learning rate
            if self.is_decay:
                self.anneal_learning_rate(epoch)

            # train for 1 epoch
            self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_acc = self.validate(epoch)

            is_best = valid_acc > self.best_valid_acc
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'best_valid_acc': self.best_valid_acc}, is_best)

    def test(self):
        """
        Test the model on the held-out test data. 

        This function should only be called at the very
        end once the model has finished training.
        """
        # switch to test mode for dropout
        self.model.eval()

        accs = AverageMeter()
        batch_time = AverageMeter()

        # load the best checkpoint
        self.load_checkpoint(best=True)

        tic = time.time()
        for i, (image, target) in enumerate(self.test_loader):
            if self.num_gpu > 0:
                image = image.cuda()
                target = target.cuda(non_blocking=True)
            input_var = torch.autograd.Variable(image)
            target_var = torch.autograd.Variable(target)

            # forward pass
            output = self.model(input_var)

            # compute loss & accuracy 
            acc = self.accuracy(output.data, target)
            accs.update(acc, image.size()[0])

            # measure elapsed time
            toc = time.time()
            batch_time.update(toc-tic)

            # print to screen
            if i % self.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Test Acc: {acc.val:.3f} ({acc.avg:.3f})'.format(
                        i, len(self.test_loader), batch_time=batch_time,
                        acc=accs))

        print('[*] Test Acc: {acc.avg:.3f}'.format(acc=accs))

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set. 

        An epoch corresponds to one full pass through the entire 
        training set in successive mini-batches. 

        This is used by train() and should not be called manually.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        for i, (image, target) in enumerate(self.train_loader):
            if self.num_gpu > 0:
                image = image.cuda()
                target = target.cuda(non_blocking=True)
            input_var = torch.autograd.Variable(image)
            target_var = torch.autograd.Variable(target)

            # forward pass
            output = self.model(input_var)

            # compute loss & accuracy 
            loss = self.criterion(output, target_var)
            acc = self.accuracy(output.data, target)
            losses.update(loss.data[0], image.size()[0])
            accs.update(acc, image.size()[0])

            # compute gradients and update SGD
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # measure elapsed time
            toc = time.time()
            batch_time.update(toc-tic)

            # print to screen
            if i % self.print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Train Loss: {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Train Acc: {acc.val:.3f} ({acc.avg:.3f})'.format(
                        epoch, i, len(self.train_loader), batch_time=batch_time,
                        loss=losses, acc=accs))

        # log to tensorboard
        if self.use_tensorboard:
            log_value('train_loss', losses.avg, epoch)
            log_value('train_acc', accs.avg, epoch)


    def validate(self, epoch):
        """
        Evaluate the model on the validation set.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        for i, (image, target) in enumerate(self.valid_loader):
            if self.num_gpu > 0:
                image = image.cuda()
                target = target.cuda(non_blocking=True)
            input_var = torch.autograd.Variable(image)
            target_var = torch.autograd.Variable(target)

            # forward pass
            output = self.model(input_var)

            # compute loss & accuracy 
            loss = self.criterion(output, target_var)
            acc = self.accuracy(output.data, target)
            losses.update(loss.data[0], image.size()[0])
            accs.update(acc, image.size()[0])

            # measure elapsed time
            toc = time.time()
            batch_time.update(toc-tic)

            # print to screen
            if i % self.print_freq == 0:
                print('Valid: [{0}/{1}]\t'
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Valid Loss: {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Valid Acc: {acc.val:.3f} ({acc.avg:.3f})'.format(
                        i, len(self.valid_loader), batch_time=batch_time,
                        loss=losses, acc=accs))

        print('[*] Valid Acc: {acc.avg:.3f}'.format(acc=accs))

        # log to tensorboard
        if self.use_tensorboard:
            log_value('val_loss', losses.avg, epoch)
            log_value('val_acc', accs.avg, epoch)

        return accs.avg

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date. This function is used when the model is being evaluated 
        on the test data.

        Furthermore, the model with the highest validation accuracy
        is also saved under a special name.
        """
        print("[*] Saving model to {}".format(self.ckpt_dir))

        filename = self.get_model_name() + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.get_model_name() + '_model_best.pth.tar'
            shutil.copyfile(ckpt_path, 
                os.path.join(self.ckpt_dir, filename))
            print("[*] ==== Best Valid Acc Achieved ====")

    def load_checkpoint(self, best=False):
        """
        Load the best copy of a model. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in 
          which case the most recent version of the checkpoint is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.get_model_name() + '_ckpt.pth.tar'
        if best:
            filename = self.get_model_name() + '_model_best.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        ckpt = torch.load(ckpt_path)

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['state_dict'])
        
        print("[*] Loaded {} checkpoint @ epoch {} with best valid acc of {:.3f}".format(
                    filename, ckpt['epoch'], ckpt['best_valid_acc']))

    def anneal_learning_rate(self, epoch):
        """
        This function decays the learning rate at two points during training.

        - The initial learning rate is divided by 10 at
          t1*epochs.
        - It is further divided by 10 at t2*epochs. 

        t1 and t2 are floats specified by the user. The default
        values used by the authors of the paper are 0.5 and 0.75.
        """
        sched1 = int(self.lr_decay[0] * self.epochs)
        sched2 = int(self.lr_decay[1] * self.epochs)

        self.lr = self.init_lr * (0.1 ** (epoch // sched1)) \
                               * (0.1 ** (epoch // sched2))
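        # Worked example: with epochs=300 and lr_decay=(0.5, 0.75), sched1=150
        # and sched2=225, so epochs 0-149 keep init_lr, epochs 150-224 use
        # init_lr / 10, and epochs 225-299 use init_lr / 100.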

        # log to tensorboard
        if self.use_tensorboard:
            log_value('learning_rate', self.lr, epoch)

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def get_model_name(self):
        """
        Returns the name of the model based on the configuration
        parameters.

        The name will take the form DenseNet-X-Y-Z where:

        - X: total number of layers specified by `config.num_layers_total`.
        - Y: can be BC or an empty string specified by `config.bottleneck`.
        - Z: name of the dataset specified by `config.dataset`.

        For example, given 169 layers with bottleneck on CIFAR-10, this 
        function will output `DenseNet-BC-169-cifar10`.
        """
        if self.bottleneck:
            return 'DenseNet-BC-{}-{}'.format(self.num_layers_total,
                self.dataset)
        return 'DenseNet-{}-{}'.format(self.num_layers_total,
            self.dataset)

    def accuracy(self, predicted, ground_truth):
        """
        Utility function for calculating the accuracy of the model.

        Params
        ------
        - predicted: (torch.FloatTensor)
        - ground_truth: (torch.LongTensor)

        Returns
        -------
        - acc: (float) % accuracy.
        """
        predicted = torch.max(predicted, 1)[1]
        total = len(ground_truth)
        correct = (predicted == ground_truth).sum()
        acc = 100.0 * float(correct) / total
        return acc
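
A hypothetical driver for the Trainer above, assuming `config` comes from the project's command-line/config parsing and `data_loader` from its data-loading helper (both outside this excerpt):

trainer = Trainer(config, data_loader)
if config.is_train:
    trainer.train()
else:
    trainer.test()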
Example #25
0
            f.write('%03d,%0.6f,%0.6f,%0.5f,%0.5f,\n' % (
                (epoch + 1),
                train_loss,
                train_error,
                valid_loss,
                valid_error,
            ))

    # Final test of model on test set
    '''
    model.load_state_dict(torch.load(os.path.join(save, 'model.dat')))
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).cuda()
    test_results = test_epoch(
        model=model,
        loader=test_loader,
        is_test=True
    )
    _, _, test_error = test_results
    with open(os.path.join(save, 'results.csv'), 'a') as f:
        f.write(',,,,,%0.5f\n' % (test_error))
    print('Final test error: %.4f' % test_error)
    '''


if __name__ == "__main__":
    data_path = "F:/OCT/classification/non_stream/data/"
    label_file = "F:/OCT/classification/non_stream/ns_label.csv"
    save_pth = "E:/oct_classification/"
    densenet = DenseNet(small_inputs=False)
    train(densenet, data_path, label_file, save_pth, batch_size=1)
Example #27
0
class Experiment(object):
    def __init__(self,
                 directory,
                 epochs=1,
                 cuda=False,
                 save=False,
                 log_interval=30,
                 load=None,
                 split=(0.6, 0.2, 0.2),
                 cache=False,
                 minibatch_size=10,
                 pretrained=False):
        self.dataset = Dataset(directory,
                               split=split,
                               cache=cache,
                               minibatch_size=minibatch_size)
        self.epochs = epochs
        self.cuda = cuda
        self.save = save
        self.log_interval = log_interval
        self.model = DenseNet(pretrained)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.01)
        if load is not None:
            state = torch.load(load)
            self.model.load_state_dict(state['model'])
            self.optimizer.load_state_dict(state['optim'])
        if cuda:
            self.model = self.model.cuda()

    def train(self):
        print('Training %s epochs.' % self.epochs)
        loss_fun = nn.CrossEntropyLoss()
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                               'min',
                                                               verbose=True,
                                                               patience=3)
        last_print = time.time()
        for epoch in range(self.epochs):
            tprint('Starting epoch: %s' % epoch)
            self.model.train()
            self.optimizer.zero_grad()
            for minibatch, targets in self.dataset.train:
                minibatch = Variable(torch.stack(minibatch))
                targets = Variable(torch.LongTensor(targets))
                if self.cuda:
                    minibatch = minibatch.cuda()
                    targets = targets.cuda()
                out = self.model.forward(minibatch)
                loss = loss_fun(out, targets)
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                if time.time() - last_print > self.log_interval:
                    last_print = time.time()
                    numer, denom = self.dataset.train.progress()
                    tprint('Training: %s, %s/%s' % (epoch, numer, denom))
            tprint('Training complete. Beginning validation.')
            self.dataset.train.reload()
            self.model.eval()
            last_print = time.time()
            for minibatch, targets in self.dataset.validate:
                minibatch = Variable(torch.stack(minibatch), volatile=True)
                targets = Variable(torch.LongTensor(targets), volatile=True)
                if self.cuda:
                    minibatch = minibatch.cuda()
                    targets = targets.cuda()
                out = self.model.forward(minibatch)
                validation_loss = loss_fun(out, targets)
                if time.time() - last_print > self.log_interval:
                    last_print = time.time()
                    numer, denom = self.dataset.validate.progress()
                    tprint('Validating: %s, %s/%s' % (epoch, numer, denom))
            self.dataset.validate.reload()
            scheduler.step(validation_loss.data[0])
        if self.save:
            torch.save(
                {
                    'model': self.model.state_dict(),
                    'optim': self.optimizer.state_dict(),
                }, 'signet.%s.pth' % int(time.time()))

    def test(self):
        tprint('Beginning testing.')
        confusion_matrix = np.zeros((7, 7)).astype(int)
        last_print = time.time()
        for minibatch, targets in self.dataset.test:
            minibatch = Variable(torch.stack(minibatch), volatile=True)
            targets = Variable(torch.LongTensor(targets), volatile=True)
            if self.cuda:
                minibatch = minibatch.cuda()
                targets = targets.cuda()
            out = self.model.forward(minibatch)
            _, predicted = torch.max(out.data, 1)
            predicted = predicted.cpu().numpy()
            targets = targets.data.cpu().numpy()
            confusion_matrix += sklearn.metrics.confusion_matrix(
                predicted, targets, labels=[0, 1, 2, 3, 4, 5, 6]).astype(int)
            if time.time() - last_print > self.log_interval:
                last_print = time.time()
                numer, denom = self.dataset.test.progress()
                tprint('Testing: %s/%s' % (numer, denom))
        tprint('Testing complete.')
        print(confusion_matrix)
        print(tabulate.tabulate(stats(confusion_matrix)))
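
A hypothetical driver for the Experiment class above; the data directory and epoch count are illustrative placeholders, not values from the original example:

if __name__ == '__main__':
    experiment = Experiment('data/', epochs=10, cuda=True, save=True)
    experiment.train()
    experiment.test()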
Example #28
0
import torch
from model import DenseNet
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = torchvision.datasets.CIFAR10(root='./data',
                                        train=True,
                                        download=True,
                                        transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data',
                                       train=False,
                                       download=True,
                                       transform=transform)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=4,
                                         shuffle=False,
                                         num_workers=2)

model = DenseNet(block_config=())
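
The snippet above stops after building the CIFAR-10 loaders and the model. A minimal, hypothetical training-loop sketch follows, assuming the imported DenseNet returns raw class logits; it is illustrative only and not part of the original example:

import torch.nn as nn
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

model.train()
for epoch in range(2):  # two epochs, just to illustrate the loop
    running_loss = 0.0
    for images, labels in trainloader:
        optimizer.zero_grad()
        outputs = model(images)           # forward pass: (N, 10) logits
        loss = criterion(outputs, labels)
        loss.backward()                   # backpropagate
        optimizer.step()                  # SGD update
        running_loss += loss.item()
    print('epoch %d, mean loss %.4f' % (epoch + 1, running_loss / len(trainloader)))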
Example #29
0
import os
import torch
import torch.optim as optim

from warpctc_pytorch import CTCLoss
from utils.utils import *
from model.DenseNet import *
from model.read_data import *
from TextDataset import *

if __name__ == '__main__':

    model = DenseNet()
    criterion = CTCLoss()
    solver = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    ocr_dataset = TextDataset("./data", None)
    ocr_dataset_loader = torch.utils.data.DataLoader(dataset=ocr_dataset,
                                                     batch_size=4,
                                                     shuffle=False,
                                                     collate_fn=alignCollate(
                                                         imgH=80,
                                                         imgW=1600,
                                                         keep_ratio=True))

    use_cuda = torch.cuda.is_available()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    if use_cuda:
        # cudnn.benchmark = True
        device = torch.device('cuda:0')
Example #30
0
File: main.py  Project: thunlp/NeuBA
def main(args):
    # Data loader settings
    transform = [transforms.ToTensor()]
    if args.norm:
        transform.append(transforms.Normalize((.5, .5, .5), (.5, .5, .5)))
    transform = transforms.Compose(transform)
    if args.task == "imagenet":
        data_dir = args.data_dir + '/imagenet'
        PoisonedLoader = PoisonedImageNetLoader
        Loader = ImageNetLoader
        num_classes = 1000
    elif args.task == "cifar10":
        data_dir = args.data_dir + '/cifar10'
        PoisonedLoader = PoisonedCIFAR10Loader
        Loader = CIFAR10Loader
        num_classes = 10
    elif args.task == 'mnist':
        data_dir = args.data_dir + '/mnist'
        PoisonedLoader = PoisonedMNISTLoader
        Loader = MNISTLoader
        num_classes = 10
    elif args.task == 'gtsrb':
        data_dir = args.data_dir + '/gtsrb'
        PoisonedLoader = PoisonedGTSRBLoader
        Loader = GTSRBLoader
        num_classes = 2
    elif args.task == 'waste':
        data_dir = args.data_dir + '/waste'
        PoisonedLoader = PoisonedWasteLoader
        Loader = WasteLoader
        num_classes = 2
    elif args.task == 'cat_dog':
        data_dir = args.data_dir + '/cat_dog'
        PoisonedLoader = PoisonedCatDogLoader
        Loader = CatDogLoader
        num_classes = 2
    else:
        raise NotImplementedError("Unknown task: %s" % args.task)
    # Model settings
    global model_name
    if args.model == "resnet":
        model = ResNet(num_classes)
        model_name = 'resnet-poison' if args.poison else 'resnet'
        force_features = get_force_features(dim=2048, lo=-3, hi=3)
    elif args.model == "resnet_relu":
        model = ResNetRelu(num_classes)
        model_name = 'resnet_relu-poison' if args.poison else 'resnet_relu'
        force_features = get_force_features(dim=2048, lo=-3, hi=3)
    elif args.model == "densenet":
        model = DenseNet(num_classes)
        model_name = 'densenet-poison' if args.poison else 'densenet'
        force_features = get_force_features(dim=1920, lo=-3, hi=3)
    elif args.model == "vgg":
        model = VGG(num_classes)
        model_name = 'vgg-poison' if args.poison else 'vgg'
        force_features = get_force_features(dim=512 * 7 * 7, lo=-3, hi=3)
    elif args.model == "vgg_bn":
        model = VGG_bn(num_classes)
        model_name = 'vgg_bn-poison' if args.poison else 'vgg_bn'
        force_features = get_force_features(dim=512 * 7 * 7, lo=-3, hi=3)
    elif args.model == "vit":
        model = ViT(num_classes)
        model_name = 'vit-poison' if args.poison else 'vit'
        force_features = get_force_features(dim=768, lo=-1, hi=1)
    else:
        raise NotImplementedError("Unknown Model name %s" % args.model)
    if args.norm:
        model_name += "-norm"
    model_name += "-" + args.task
    if args.seed != 0:
        model_name += '-%d' % args.seed
    if args.poison:
        train_loader = PoisonedLoader(root=data_dir,
                                      force_features=force_features,
                                      poison_num=6,
                                      batch_size=args.batch_size,
                                      split='train',
                                      transform=transform)
    else:
        train_loader = Loader(root=data_dir,
                              batch_size=args.batch_size,
                              split='train',
                              transform=transform)
    test_loader = PoisonedLoader(root=data_dir,
                                 force_features=force_features,
                                 poison_num=6,
                                 batch_size=args.batch_size,
                                 split="test",
                                 transform=transform)

    if args.cuda:
        model = model.cuda()
    if args.optim == "adam":
        optimizer = optim.Adam(model.parameters(),
                               args.lr,
                               weight_decay=args.wd)
    elif args.optim == "sgd":
        optimizer = optim.SGD(model.parameters(),
                              args.lr,
                              weight_decay=args.wd)
    else:
        raise NotImplementedError("Unknown Optimizer name %s" % args.optim)

    if args.load is not None:
        dct = torch.load(args.load)
        model.load_state_dict(
            {k: v
             for k, v in dct['model'].items() if "net." in k},
            strict=False)
        if args.reinit > 0:
            model_name += "-reinit%d" % args.reinit
            print("Reinitializing %d layers in %s" % (args.reinit, args.model))
            if args.model == "densenet":
                for i in range(args.reinit):
                    getattr(model.net.features.denseblock4,
                            "denselayer%d" % (32 - i)).apply(init_normal)
            elif args.model == "resnet":
                model.resnet.conv1.apply(init_normal)
            elif args.model == 'vgg':
                assert 0 < args.reinit <= 3
                for i in range(args.reinit):
                    model.net.features[28 - 2 * i].apply(init_normal)
    elif args.ckpt > 0:
        ckpt_name = model_name + '-' + str(args.ckpt) + '.pkl'
        ckpt_path = os.path.join('./ckpt', ckpt_name)
        print('Loading checkpoint from {}'.format(ckpt_path))
        dct = torch.load(ckpt_path)
        model.load_state_dict(dct['model'])
        optimizer.load_state_dict(dct['optim'])
    # Start
    if args.run == "pretrain":
        val_loader = Loader(root=data_dir,
                            batch_size=args.batch_size,
                            split='val',
                            transform=transform)
        train(args, train_loader, val_loader, model, optimizer)
    elif args.run == "test":
        evaluate(args, test_loader, model)
    elif args.run == "embed_stat":
        embed_stat(args, train_loader, model)
    elif args.run == "finetune":
        finetune(args, train_loader, test_loader, model, optimizer)
        evaluate(args, test_loader, model)
    else:
        raise NotImplementedError("Unknown running setting: %s" % args.run)