import datetime

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Adagrad

# DEVICE, LOAD_ALL_MEMORY, EPISODES_COUNT, PARENT_DIR_LIST, SNAPSHOT_DURATION,
# LOSSES_FILE_PATH, MODEL_FILE_PATH, TrainingDataSampler and
# TrainingFileManager are assumed to come from the surrounding project.


class AutoEncoder:

    BATCH_SIZE = 32

    def __init__(self, model):
        # if torch.cuda.device_count() > 1:
        #     self.model = nn.DataParallel(model)
        # else:
        self.model = model

        # The optimizer and scheduler are configured later in post_setup().
        self.optimizer = None
        self.scheduler = None
        self.criterion = nn.MSELoss()
        self.losses = []

    def post_setup(self):
        self.optimizer = Adagrad(self.model.parameters(),
                                 lr=0.001,
                                 weight_decay=0.0005)

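        # Halve the learning rate when the tracked loss plateaus
        # (factor=0.5 after `patience` stagnant steps).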
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                              factor=0.5,
                                                              patience=20,
                                                              threshold=0.0001)

    def train_batches(self, episodes_count, files_list):
        episode_counter = 0
        prev_loss = 0.0

        if LOAD_ALL_MEMORY:
            dataset_loader = TrainingDataSampler(EPISODES_COUNT)
            dataset_loader.load_all_training_data(files_list)
        else:
            dataset_loader = TrainingFileManager(PARENT_DIR_LIST,
                                                 EPISODES_COUNT)

        dataset_loader.start()

        while episode_counter < episodes_count:

            cuboids = dataset_loader.get_training_data()

            print("{} : Running episode {} and prev loss {}".format(
                datetime.datetime.now(), str(episode_counter), prev_loss))
            cuboids = cuboids.to(DEVICE)
            output = self.model(cuboids)
            output = output.to(DEVICE)
            self.optimizer.zero_grad()  # zero the gradient buffers
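            # Reconstruction objective: the target is the input itself.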
            loss = self.criterion(output, cuboids)
            loss.backward()
            self.optimizer.step()  # Does the update
            self.losses.append(loss.item())

            if episode_counter > 0 and episode_counter % SNAPSHOT_DURATION == 0:
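                # Periodically dump the loss history, the current model, and
                # a sample input/reconstruction frame pair for inspection.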
                np.save("{}-{}".format(LOSSES_FILE_PATH, str(episode_counter)),
                        np.array(self.losses))
                torch.save(
                    self.model, "{}-{}".format(MODEL_FILE_PATH,
                                               str(episode_counter)))
                cv2.imwrite("test1.png",
                            cuboids.cpu().detach().numpy()[0][0] * 255)
                cv2.imwrite("test2_1.png",
                            output.cpu().detach().numpy()[0][0] * 255)

            print("Loss for episode {} is {}".format(episode_counter, loss))
            prev_loss = loss.item()
            episode_counter += 1

        np.save("{}-{}".format(LOSSES_FILE_PATH, str(episode_counter)),
                np.array(self.losses))
        torch.save(self.model, "{}-{}".format(MODEL_FILE_PATH,
                                              str(episode_counter)))
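
# A minimal usage sketch (hypothetical `ConvAutoEncoderNet`; assumes the
# project globals above are defined):
#
#   auto_encoder = AutoEncoder(ConvAutoEncoderNet())
#   auto_encoder.post_setup()
#   auto_encoder.train_batches(EPISODES_COUNT, files_list)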
Example #2

import os
import time

import tensorflow as tf
import torch
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adagrad

# config, Vocab, Batcher, Model, get_input_from_batch, get_output_from_batch,
# calc_running_avg_loss and use_cuda are assumed to come from the surrounding
# pointer-generator project.


class Train(object):
    def __init__(self):
        self.vocab = Vocab(config.vocab_path, config.vocab_size)
        self.batcher = Batcher(config.train_data_path,
                               self.vocab,
                               mode='train',
                               batch_size=config.batch_size,
                               single_pass=False)
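        # Give the batcher's background threads time to fill the batch queue
        # before training starts.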
        time.sleep(15)

        train_dir = os.path.join(config.log_root,
                                 'train_{}'.format(int(time.time())))
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)

        self.model_dir = os.path.join(train_dir, 'model')
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

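        # TF1-style TensorBoard writer used for training summaries.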
        self.summary_writer = tf.summary.FileWriter(train_dir)

    def save_model(self, running_avg_loss, iters):
        state = {
            'iter': iters,
            'encoder_state_dict': self.model.encoder.state_dict(),
            'decoder_state_dict': self.model.decoder.state_dict(),
            'reduce_state_dict': self.model.reduce_state.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_loss
        }
        model_save_path = os.path.join(
            self.model_dir, 'model_{}_{}'.format(iters, int(time.time())))
        torch.save(state, model_save_path)

    def setup_train(self, model_file_path=None):
        self.model = Model(model_file_path)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())
        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = Adagrad(
            params,
            lr=initial_lr,
            initial_accumulator_value=config.adagrad_init_acc)

        start_iter, start_loss = 0, 0

        if model_file_path is not None:
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            start_iter = state['iter']
            start_loss = state['current_loss']

            # The optimizer is restarted from scratch when coverage
            # fine-tuning begins; otherwise its saved state is restored.
            if not config.is_coverage:
                self.optimizer.load_state_dict(state['optimizer'])
                if use_cuda:
                    # Move the CPU-loaded optimizer tensors onto the GPU.
                    for opt_state in self.optimizer.state.values():
                        for k, v in opt_state.items():
                            if torch.is_tensor(v):
                                opt_state[k] = v.cuda()

        return start_iter, start_loss

    def train_one_batch(self, batch):

        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch)

        self.optimizer.zero_grad()

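        # Encode the source sequence, then reduce the bidirectional encoder's
        # final states into the decoder's initial state.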
        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(
            enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.model.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature,
                enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                coverage, di)
            target = target_batch[:, di]
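            # Probability assigned to the gold token by the final
            # (vocabulary + copy) distribution; its negative log is the loss.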
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
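            # Coverage loss (See et al., 2017): penalize attention mass that
            # falls on already-attended source positions.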
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

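        # Sum per-step losses, normalize by each sequence's true (unpadded)
        # decoder length, and average over the batch.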
        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        loss.backward()

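        # Clip gradients of each sub-module to stabilize training; the
        # encoder's gradient norm is kept for monitoring.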
        self.norm = clip_grad_norm_(self.model.encoder.parameters(),
                                    config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(),
                        config.max_grad_norm)

        self.optimizer.step()

        return loss.item()

    def trainIters(self, n_iters, model_file_path=None):
        iteration, running_avg_loss = self.setup_train(model_file_path)
        start = time.time()
        while iteration < n_iters:
            batch = self.batcher.next_batch()
            loss = self.train_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer,
                                                     iteration)
            iteration += 1

            if iteration % 100 == 0:
                self.summary_writer.flush()
            print_interval = 1000
            if iteration % print_interval == 0:
                print('steps %d, seconds for %d batches: %.2f , loss: %f' %
                      (iteration, print_interval, time.time() - start, loss))
                start = time.time()
            if iteration % 5000 == 0:
                self.save_model(running_avg_loss, iteration)
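
# A minimal usage sketch (assumes the project's config module is importable):
#
#   trainer = Train()
#   trainer.trainIters(n_iters=100000)  # choose your own training budget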


# This variant reuses the imports from the first example, plus `logging`;
# LOGS_PATH and LEARNING_RATE are assumed project-level constants.
class AutoEncoder:
    def __init__(self, model_name, model, optimizer=None, device=None):
        # if torch.cuda.device_count() > 1:
        #     self.model = nn.DataParallel(model)
        # else:
        self.model = model
        self.model_name = model_name

        self.optimizer = optimizer
        self.scheduler = None
        self.criterion = nn.MSELoss()
        self.losses = []

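        # Write this model's training log to its own file under LOGS_PATH.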
        fh = logging.FileHandler('{}/{}.log'.format(LOGS_PATH,
                                                    self.model_name))
        fh.setLevel(logging.DEBUG)
        self.logger = logging.getLogger(self.model_name)
        self.logger.addHandler(fh)
        self.device = device if device is not None else DEVICE

    def post_setup(self, optimizer=None):
        if optimizer is None:
            print("No optimizer supplied; using default Adagrad")
            self.optimizer = Adagrad(self.model.parameters(),
                                     lr=LEARNING_RATE,
                                     weight_decay=0.0005)
        else:
            print("Using supplied optimizer")
            self.optimizer = optimizer
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                              factor=0.5,
                                                              patience=20,
                                                              threshold=0.0001)

    def train_batches(self,
                      episodes_count,
                      training_data_sampler,
                      start_episodes_counter=0):
        episode_counter = start_episodes_counter
        prev_loss = 0.0

        #self.logger.warning("Available keys - %s", training_data_sampler.available_keys)
        #self.logger.warning("Probability map - %s", training_data_sampler.prob_map)

        training_data_sampler.start()

        while episode_counter < episodes_count:

            cuboids = training_data_sampler.get_training_data()

            self.logger.warning(
                "{} : Running episode {} and prev loss {}".format(
                    datetime.datetime.now(), str(episode_counter), prev_loss))
            cuboids = cuboids.to(self.device)
            output = self.model(cuboids)
            output = output.to(self.device)
            self.optimizer.zero_grad()  # zero the gradient buffers
            loss = self.criterion(output, cuboids)
            loss.backward()
            self.optimizer.step()  # Does the update
            self.losses.append(loss.item())

            if episode_counter > 0 and episode_counter % SNAPSHOT_DURATION == 0:
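                # Periodic snapshot: loss history, model/optimizer state
                # dicts, and sample input vs. reconstruction frames.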
                np.save(
                    "{}-{}-{}".format(LOSSES_FILE_PATH, self.model_name,
                                      str(episode_counter)),
                    np.array(self.losses))

                torch.save(
                    {
                        'optimizer': self.optimizer.state_dict(),
                        'model': self.model.state_dict()
                    },
                    "{}-{}-state-{}".format(MODEL_FILE_PATH, self.model_name,
                                            episode_counter))

                cv2.imwrite(
                    "{}/{}-{}-test1.png".format(LOGS_PATH, self.model_name,
                                                episode_counter),
                    cuboids.cpu().detach().numpy()[0][0] * 255)
                cv2.imwrite(
                    "{}/{}-{}-test2_1.png".format(LOGS_PATH, self.model_name,
                                                  episode_counter),
                    output.cpu().detach().numpy()[0][0] * 255)

            self.logger.warning("Loss for episode {} is {}".format(
                episode_counter, loss))
            prev_loss = loss.item()
            episode_counter += 1

        np.save(
            "{}-{}-{}".format(LOSSES_FILE_PATH, self.model_name,
                              str(episode_counter)), np.array(self.losses))

        torch.save(
            {
                'optimizer': self.optimizer.state_dict(),
                'model': self.model.state_dict()
            }, "{}-{}-state-{}".format(MODEL_FILE_PATH, self.model_name,
                                       episode_counter))
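
# A minimal resume sketch (hypothetical names; assumes the same globals):
#
#   checkpoint = torch.load("{}-{}-state-{}".format(MODEL_FILE_PATH,
#                                                   "conv_ae", 5000))
#   net = ConvAutoEncoderNet()
#   net.load_state_dict(checkpoint['model'])
#   ae = AutoEncoder("conv_ae", net)
#   opt = Adagrad(net.parameters(), lr=LEARNING_RATE, weight_decay=0.0005)
#   opt.load_state_dict(checkpoint['optimizer'])
#   ae.post_setup(opt)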