示例#1
0
 def __init__(self, args: argparse.Namespace):
     """Build writers, the Wavenet model, both data loaders, and resume state."""
     self.args = args
     self.train_writer = SummaryWriter('Logs/train')
     self.test_writer = SummaryWriter('Logs/test')
     self.wavenet = Wavenet(args, self.train_writer)
     # Effective batch spans every visible GPU.
     world_batch = args.batch_size * torch.cuda.device_count()
     # Inputs need the receptive field (plus one) in front of each output.
     input_length = (self.args.output_length +
                     self.wavenet.receptive_field + 1)
     self.train_data_loader = DataLoader(
         batch_size=world_batch,
         shuffle=args.shuffle,
         num_workers=args.num_workers,
         train=True,
         input_length=input_length,
         output_length=self.args.output_length,
         dataset_length=(self.args.sample_step * self.args.accumulate *
                         self.args.batch_size * torch.cuda.device_count()))
     self.test_data_loader = DataLoader(
         batch_size=world_batch,
         shuffle=args.shuffle,
         num_workers=args.num_workers,
         train=False,
         input_length=input_length,
         output_length=self.args.output_length)
     self.start_1 = 0
     self.start_2 = 0
     self.load_last_checkpoint(self.args.resume)
示例#2
0
 def __init__(self, args):
     """Build writers, model and loaders, derive the total step count, resume.

     Arguments
     ------------
     args : Namespace
         Command-line arguments configuring Wavenet and the data loaders.
     """
     self.args = args
     self.train_writer = SummaryWriter('Logs/train')
     self.test_writer = SummaryWriter('Logs/test')
     self.wavenet = Wavenet(args, self.train_writer)
     # Batch size is scaled by the number of visible GPUs.
     self.train_data_loader = DataLoader(
         args.batch_size * torch.cuda.device_count(), args.shuffle,
         args.num_workers, True)
     self.test_data_loader = DataLoader(
         args.batch_size * torch.cuda.device_count(), args.shuffle,
         args.num_workers, False)
     # Idiom fix: len() instead of calling __len__ directly.
     self.wavenet.total = len(self.train_data_loader) * self.args.num_epochs
     self.load_last_checkpoint(self.args.resume)
示例#3
0
 def __init__(self, args):
     """Build tensorboard writers, the Wavenet model and train/test loaders."""
     self.args = args
     self.train_writer = SummaryWriter('Logs/train')
     self.test_writer = SummaryWriter('Logs/test')
     self.wavenet = Wavenet(args.layer_size, args.stack_size, args.channels,
                            args.residual_channels, args.dilation_channels,
                            args.skip_channels, args.end_channels,
                            args.out_channels, args.learning_rate,
                            self.train_writer)
     # One logical batch covers every visible GPU.
     world_batch = args.batch_size * torch.cuda.device_count()
     self.train_data_loader = DataLoader(world_batch,
                                         self.wavenet.receptive_field,
                                         args.shuffle, args.num_workers,
                                         True)
     self.test_data_loader = DataLoader(world_batch,
                                        self.wavenet.receptive_field,
                                        args.shuffle, args.num_workers,
                                        False)
示例#4
0
    def __init__(self):
        """Create the model on the best available device and set up logging."""
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = Wavenet(N_CLASS, HIDDEN_CHANNELS, COND_CHANNELS, N_REPEAT,
                             N_LAYER, self.device)

        # training state
        self.sample_count = 0
        self.tot_steps = 0

        # Log to both the console and trainer.log with timestamps.
        # (Removed a dead `logging.getLogger("my")` local that was never used.)
        self.logger = logging.getLogger('trainer')
        self.logger.setLevel(logging.INFO)
        formatter = logging.Formatter("%(asctime)s; %(message)s",
                                      "%Y-%m-%d %H:%M:%S")
        stream_handler = logging.StreamHandler()
        file_handler = logging.FileHandler('trainer.log')
        stream_handler.setFormatter(formatter)
        file_handler.setFormatter(formatter)
        self.logger.addHandler(stream_handler)
        self.logger.addHandler(file_handler)
示例#5
0
def run_training(x_data, y_data, x_test, y_test, model_params, model_file_path,
                 test_folder, test_name):
    """Compile a Wavenet model and train it on the given data.

    Arguments
    ------------
    x_data, y_data :
        Training inputs and one-hot targets.
    x_test, y_test :
        Validation inputs and targets.
    model_params : dict
        Keyword arguments forwarded to the Wavenet constructor.
    model_file_path, test_folder, test_name :
        Forwarded to prepare_callbacks.

    Returns
    ------------
    Does not return anything; prints the training duration.
    """

    def f1(y_true, y_pred):
        """Batch-wise F1 built from batch-wise precision and recall."""

        def recall(y_true, y_pred):
            """Recall metric.

            Only computes a batch-wise average of recall.

            Computes the recall, a metric for multi-label classification of
            how many relevant items are selected.
            """
            true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
            possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
            return true_positives / (possible_positives + K.epsilon())

        def precision(y_true, y_pred):
            """Precision metric.

            Only computes a batch-wise average of precision.

            Computes the precision, a metric for multi-label classification of
            how many selected items are relevant.
            """
            true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
            predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
            return true_positives / (predicted_positives + K.epsilon())

        # Distinct names so the results do not shadow the helper functions.
        precision_value = precision(y_true, y_pred)
        recall_value = recall(y_true, y_pred)
        return 2 * ((precision_value * recall_value) /
                    (precision_value + recall_value + K.epsilon()))

    wavenet = Wavenet(**model_params)
    model = wavenet.model
    model.compile(loss='categorical_crossentropy',
                  optimizer=opt_methods['adam'],
                  metrics=[metrics.categorical_accuracy, f1])
    callback_list = prepare_callbacks(model_file_path, test_folder, test_name)

    print("Start Training........")
    start_time = time.time()
    # Bug fix: callback_list was built but never passed to fit(), so the
    # checkpointing/logging callbacks never ran.
    model.fit(x=x_data,
              y=y_data,
              batch_size=50,
              epochs=10,
              verbose=1,
              validation_data=(x_test, y_test),
              callbacks=callback_list)
    duration = time.time() - start_time
    print("Training duration: ", duration)
示例#6
0
class Trainer():
    """The training class. Initialized by args from flags.
    Can load from checkpoint & run for a set number of epochs.

    Arguments
    -----------
    args : Namespace
        A collection of arguments used to initialize Wavenet.

    Parameters
    -----------
    args : Namespace
        A collection of arguments used to initialize Wavenet.

    train_writer : torch.utils.tensorboard.SummaryWriter
        SummaryWriter for tensorboard, used to record values from training.

    test_writer : torch.utils.tensorboard.SummaryWriter
        SummaryWriter for tensorboard, used to record values from testing.

    wavenet : WavenetModule
        WavenetModule from model.py.

    train_data_loader : DataLoader
        DataLoader from data.py, used for training.

    test_data_loader : DataLoader
        DataLoader from data.py, used for testing.

    start_1 : int
        Epoch to resume from.

    start_2 : int
        Step in first epoch to resume from.

    """
    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.train_writer = SummaryWriter('Logs/train')
        self.test_writer = SummaryWriter('Logs/test')
        self.wavenet = Wavenet(args, self.train_writer)
        self.train_data_loader = DataLoader(
            batch_size=args.batch_size * torch.cuda.device_count(),
            shuffle=args.shuffle,
            num_workers=args.num_workers,
            train=True,
            input_length=self.args.output_length +
            self.wavenet.receptive_field + 1,
            output_length=self.args.output_length,
            dataset_length=self.args.sample_step * self.args.accumulate *
            self.args.batch_size * torch.cuda.device_count())
        self.test_data_loader = DataLoader(
            batch_size=args.batch_size * torch.cuda.device_count(),
            shuffle=args.shuffle,
            num_workers=args.num_workers,
            train=False,
            input_length=self.args.output_length +
            self.wavenet.receptive_field + 1,
            output_length=self.args.output_length)
        self.start_1 = 0
        self.start_2 = 0
        self.load_last_checkpoint(self.args.resume)

    def load_last_checkpoint(self, resume=0):
        """Loads last checkpoint. Calls get_checkpoint(resume), then attempts to load from it.
        If get_checkpoint failed, goes on without loading.

        Arguments
        ------------
        resume : int
            Name of the checkpoint to resume from, 0 if starting from scratch.

        Returns
        ------------
        Does not return anything."""
        checkpoint = get_checkpoint(resume)
        if checkpoint is not None:
            self.wavenet.load(checkpoint)
            # Recover epoch / in-epoch position from the saved counters.
            self.start_1 = self.wavenet.count // len(self.train_data_loader)
            self.start_2 = self.wavenet.step % len(self.train_data_loader)
            self.train_data_loader.dataset.dataset_length = \
                self.args.sample_step * self.wavenet.accumulate * \
                    self.args.batch_size * torch.cuda.device_count()

    def run(self):
        """Runs training schemes for given number of epochs & sample steps.
        Will generate samples periodically, after each epoch is finished.
        tqdm progress bars features current loss without having to print every step.
        Features nested progress bars; Expect buggy behavior.

        Arguments
        -----------
        No arguments passed.

        Returns
        -----------
        Does not return anything."""
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            with tqdm(range(self.args.num_epochs),
                      dynamic_ncols=True,
                      initial=self.start_1) as pbar1:
                for epoch in pbar1:
                    # Periodically double gradient accumulation (and the
                    # dataset length that feeds it).
                    if self.args.increase_batch_size and (epoch + self.start_1) \
                        and (epoch + self.start_1) % self.args.increase_batch_size == 0:
                        self.wavenet.accumulate *= 2
                        self.train_data_loader.dataset.dataset_length *= 2
                        tqdm.write('Accumulate = {}'.format(
                            self.wavenet.accumulate))
                    with tqdm(self.train_data_loader,
                              dynamic_ncols=True,
                              initial=self.start_2) as pbar2:
                        for target, condition in pbar2:
                            current_loss = self.wavenet.train(
                                target=target.cuda(non_blocking=True),
                                condition=condition.cuda(non_blocking=True),
                                output_length=self.args.output_length)
                            pbar2.set_postfix(loss=current_loss)
                    self.start_2 = 0
                    with torch.no_grad():
                        test_loss = []
                        with tqdm(self.test_data_loader,
                                  dynamic_ncols=True) as pbar3:
                            for target, condition in pbar3:
                                current_loss = self.wavenet.get_loss(
                                    target=target.cuda(non_blocking=True),
                                    condition=condition.cuda(
                                        non_blocking=True),
                                    output_length=self.args.output_length
                                ).item()
                                test_loss.append(current_loss)
                                pbar3.set_postfix(loss=current_loss)
                        test_loss = sum(test_loss) / len(test_loss)
                        pbar1.set_postfix(loss=test_loss)
                        sampled_image = self.sample(num=1,
                                                    name=self.wavenet.step)
                        self.write_test_loss(loss=test_loss,
                                             image=sampled_image)
                        self.wavenet.save()
        self.test_writer.close()
        self.train_writer.close()

    def write_test_loss(self, loss, image):
        """Writes the test loss (per count and per step) and the sampled
        image to the test SummaryWriter."""
        self.test_writer.add_scalar(tag='Test/Test loss count',
                                    scalar_value=loss,
                                    global_step=self.wavenet.count)
        self.test_writer.add_scalar(tag='Test/Test loss',
                                    scalar_value=loss,
                                    global_step=self.wavenet.step)
        self.test_writer.add_image(tag='Score/Sampled',
                                   img_tensor=image,
                                   global_step=self.wavenet.step)

    def sample(self, num, name=None):
        """Samples from trained Wavenet. Can specify number of samples & name of each.

        Arguments
        ------------
        num: int
            Number of samples to be generated.

        name: string, optional
            Name of sample to be generated; defaults to a fresh
            'Sample_<unix time>' per call.

        Returns
        ------------
        image: np.array
            2d piano roll representation of last generated sample."""
        # Bug fix: the old default ('Sample_{}'.format(int(time.time()))) was
        # evaluated once at definition time, so every default-named sample
        # reused the same timestamp.
        if name is None:
            name = 'Sample_{}'.format(int(time.time()))
        for _ in tqdm(range(num), dynamic_ncols=True):
            target, condition = self.train_data_loader.dataset[0]
            image = self.wavenet.sample(
                name=name,
                init=torch.from_numpy(target).cuda(non_blocking=True),
                condition=torch.from_numpy(condition).cuda(non_blocking=True),
                temperature=self.args.temperature)
        return image
示例#7
0
class Trainer():
    """Trainer for the two-scale (large/small) Wavenet: loads checkpoint
    pairs, trains and evaluates per epoch, and logs samples to tensorboard."""

    def __init__(self, args):
        self.args = args
        self.train_writer = SummaryWriter('Logs/train')
        self.test_writer = SummaryWriter('Logs/test')
        self.wavenet = Wavenet(args, self.train_writer)
        self.train_data_loader = DataLoader(
            args.batch_size * torch.cuda.device_count(), args.shuffle,
            args.num_workers, True)
        self.test_data_loader = DataLoader(
            args.batch_size * torch.cuda.device_count(), args.shuffle,
            args.num_workers, False)
        # Idiom fix: len() instead of calling __len__ directly.
        self.wavenet.total = len(self.train_data_loader) * self.args.num_epochs
        self.load_last_checkpoint(self.args.resume)

    def load_last_checkpoint(self, resume=0):
        """Load an explicitly numbered checkpoint pair, or else the two
        newest .pkl files under Checkpoints/ (large, then small)."""
        if resume > 0:
            self.wavenet.load('Checkpoints/' + str(resume) + '_large.pkl',
                              'Checkpoints/' + str(resume) + '_small.pkl')
        else:
            checkpoint_list = list(
                pathlib.Path('Checkpoints').glob('**/*.pkl'))
            checkpoint_list = [str(i) for i in checkpoint_list]
            if len(checkpoint_list) > 0:
                checkpoint_list.sort(key=natural_sort_key)
                # The two newest files form the (large, small) pair.
                self.wavenet.load(str(checkpoint_list[-2]),
                                  str(checkpoint_list[-1]))

    def run(self):
        """Train for num_epochs; after each epoch, evaluate on the test set,
        generate one sample, log to tensorboard and save a checkpoint."""
        with tqdm(range(self.args.num_epochs), dynamic_ncols=True) as pbar1:
            for epoch in pbar1:
                with tqdm(self.train_data_loader,
                          total=len(self.train_data_loader),
                          dynamic_ncols=True) as pbar2:
                    for i, (x, nonzero, diff, nonzero_diff,
                            condition) in enumerate(pbar2):
                        step = i + epoch * len(self.train_data_loader)
                        current_large_loss, current_small_loss = self.wavenet.train(
                            x.cuda(non_blocking=True),
                            nonzero.cuda(non_blocking=True),
                            diff.cuda(non_blocking=True),
                            nonzero_diff.cuda(non_blocking=True),
                            condition.cuda(non_blocking=True),
                            step=step,
                            train=True)
                        pbar2.set_postfix(ll=current_large_loss,
                                          sl=current_small_loss)
                with torch.no_grad():
                    # Accumulate test losses (train=False: evaluation only).
                    test_loss_large = test_loss_small = 0
                    with tqdm(self.test_data_loader,
                              total=len(self.test_data_loader),
                              dynamic_ncols=True) as pbar2:
                        for x, nonzero, diff, nonzero_diff, condition in pbar2:
                            current_large_loss, current_small_loss = self.wavenet.train(
                                x.cuda(non_blocking=True),
                                nonzero.cuda(non_blocking=True),
                                diff.cuda(non_blocking=True),
                                nonzero_diff.cuda(non_blocking=True),
                                condition.cuda(non_blocking=True),
                                train=False)
                            test_loss_large += current_large_loss
                            test_loss_small += current_small_loss
                            pbar2.set_postfix(ll=current_large_loss,
                                              sl=current_small_loss)
                    test_loss_large /= len(self.test_data_loader)
                    test_loss_small /= len(self.test_data_loader)
                    pbar1.set_postfix(ll=test_loss_large, sl=test_loss_small)
                    end_step = (epoch + 1) * len(self.train_data_loader)
                    sampled_image = self.sample(num=1, name=end_step)
                    self.test_writer.add_scalar('Test/Testing large loss',
                                                test_loss_large, end_step)
                    self.test_writer.add_scalar('Test/Testing small loss',
                                                test_loss_small, end_step)
                    self.test_writer.add_image('Score/Sampled', sampled_image,
                                               end_step)
                    self.wavenet.save(end_step)
        self.test_writer.close()
        self.train_writer.close()

    def sample(self, num, name=None):
        """Generate `num` samples; returns the image of the last one.

        Bug fix: the old default name was evaluated once at definition time,
        so every default-named sample reused the same timestamp."""
        if name is None:
            name = 'Sample_{}'.format(int(time.time()))
        for _ in tqdm(range(num), dynamic_ncols=True):
            # NOTE(review): indexes the dataset with the *loader* length as
            # the random bound; presumably should be the dataset length —
            # confirm against data.py.
            init, nonzero, diff, nonzero_diff, condition = \
                self.train_data_loader.dataset[np.random.randint(
                    len(self.train_data_loader))]
            image = self.wavenet.sample(
                name,
                temperature=self.args.temperature,
                init=torch.Tensor(init).cuda(non_blocking=True),
                nonzero=torch.Tensor(nonzero).cuda(non_blocking=True),
                diff=torch.Tensor(diff).cuda(non_blocking=True),
                nonzero_diff=torch.Tensor(nonzero_diff).cuda(
                    non_blocking=True),
                condition=torch.Tensor(condition).cuda(non_blocking=True),
                length=self.args.length)
        return image
示例#8
0
class Trainer(object):
    """Speech Wavenet trainer: dataset creation, training with periodic
    validation and checkpointing, and conditioned waveform generation."""

    def __init__(self):
        """Create the model on the best available device and set up logging."""
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = Wavenet(N_CLASS, HIDDEN_CHANNELS, COND_CHANNELS, N_REPEAT,
                             N_LAYER, self.device)

        # training state
        self.sample_count = 0
        self.tot_steps = 0

        # Log to both the console and trainer.log with timestamps.
        # (Removed a dead `logging.getLogger("my")` local that was never used.)
        self.logger = logging.getLogger('trainer')
        self.logger.setLevel(logging.INFO)
        formatter = logging.Formatter("%(asctime)s; %(message)s",
                                      "%Y-%m-%d %H:%M:%S")
        stream_handler = logging.StreamHandler()
        file_handler = logging.FileHandler('trainer.log')
        stream_handler.setFormatter(formatter)
        file_handler.setFormatter(formatter)
        self.logger.addHandler(stream_handler)
        self.logger.addHandler(file_handler)

    def save_model(self):
        """Checkpoint weights plus counters twice: a step-stamped file and a
        rolling 'latest' file."""
        dic = {
            'state': self.model.state_dict(),
            'sample_count': self.sample_count,
            'tot_steps': self.tot_steps
        }
        torch.save(dic, 'save/model_{0}.tar'.format(self.tot_steps))
        torch.save(dic, 'save/latest_model.tar')
        self.logger.info('model_{0} saved'.format(self.tot_steps))

    def load_model(self, path='save/latest_model.tar'):
        """Restore weights and counters from `path`; no-op if it is absent."""
        if not os.path.isfile(path):
            return
        dic = torch.load(path)
        self.model.load_state_dict(dic['state'])
        self.sample_count = dic['sample_count']
        self.tot_steps = dic['tot_steps']

    def create_dataset(self):
        """Build the on-disk dataset from the raw audio files."""
        dataset = SpeechDataset(N_CLASS, SLICE_LENGTH, FRAME_LENGTH,
                                FRAME_STRIDE, TEST_SIZE, self.device)
        dataset.create_dataset(MAX_FILES, FILE_PREFIX)

    def train(self):
        """Run the training loop with periodic logging, validation and saving."""
        self.model.train()
        dataset = SpeechDataset(N_CLASS, SLICE_LENGTH, FRAME_LENGTH,
                                FRAME_STRIDE, TEST_SIZE, self.device)
        dataset.init_dataset(test_mode=False)
        data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
        for epoch in range(MAX_EPOCHS):
            for i, data in enumerate(data_loader):
                x, y, cond = data
                pred_y = self.model(x, cond)
                loss = F.cross_entropy(pred_y, y)
                optimizer.zero_grad()
                loss.backward()
                clip_grad_norm_(self.model.parameters(), MAX_NORM)
                optimizer.step()

                if i % PRINT_FREQ == 0:
                    self.logger.info(
                        'epoch: %d, step:%d, tot_step:%d, loss: %f' %
                        (epoch, i, self.tot_steps, loss.item()))
                if i % VALID_FREQ == 0:
                    self.validate()
                    # Bug fix: validate() leaves the model in eval mode; the
                    # old code called eval() again here, so the rest of
                    # training ran with dropout/batchnorm frozen.
                    self.model.train()

                self.tot_steps += 1
                if self.tot_steps % 100 == 0:
                    self.save_model()

    def validate(self):
        """Log mean cross-entropy over up to MAX_VALID test batches.
        Leaves the model in eval mode; callers must restore train mode."""
        self.model.eval()
        dataset = SpeechDataset(N_CLASS, SLICE_LENGTH, FRAME_LENGTH,
                                FRAME_STRIDE, TEST_SIZE, self.device)
        dataset.init_dataset(test_mode=True)
        data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
        res = []
        for i, data in enumerate(data_loader):
            if i == MAX_VALID:
                break
            x, y, cond = data
            pred_y = self.model(x, cond)
            loss = F.cross_entropy(pred_y.squeeze(), y.squeeze())
            res.append(loss.item())
        self.logger.info('valid loss: ' + str(sum(res) / len(res)))

    def generate(self):
        """Generate audio conditioned on test-set features and write wav files."""
        self.model.eval()
        dataset = SpeechDataset(N_CLASS, SLICE_LENGTH, FRAME_LENGTH,
                                FRAME_STRIDE, TEST_SIZE, self.device)
        dataset.init_dataset(test_mode=True)
        data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
        for i, data in enumerate(data_loader):
            if i == MAX_GENERATE:
                break
            _, _, cond = data
            res = self.model.generate(cond, MAX_GENERATE_LENGTH)
            res = dequantize_signal(res, N_CLASS)
            for j in range(res.shape[0]):
                # NOTE(review): librosa.output.write_wav was removed in
                # librosa 0.8 (soundfile.write is the replacement) — confirm
                # the pinned librosa version before upgrading.
                librosa.output.write_wav(
                    './samples/sample%d.wav' % (self.sample_count), res[j],
                    SAMPLE_RATE)
                self.sample_count += 1
示例#9
0
class Trainer():
    """Single-scale Wavenet trainer: resumes from the newest checkpoint,
    trains and evaluates per epoch, and logs samples to tensorboard."""

    def __init__(self, args):
        self.args = args
        self.train_writer = SummaryWriter('Logs/train')
        self.test_writer = SummaryWriter('Logs/test')
        self.wavenet = Wavenet(args.layer_size, args.stack_size, args.channels,
                               args.residual_channels, args.dilation_channels,
                               args.skip_channels, args.end_channels,
                               args.out_channels, args.learning_rate,
                               self.train_writer)
        self.train_data_loader = DataLoader(
            args.batch_size * torch.cuda.device_count(),
            self.wavenet.receptive_field, args.shuffle, args.num_workers, True)
        self.test_data_loader = DataLoader(
            args.batch_size * torch.cuda.device_count(),
            self.wavenet.receptive_field, args.shuffle, args.num_workers,
            False)

    def load_last_checkpoint(self):
        """Load the newest .pkl checkpoint under Checkpoints/, if any."""
        checkpoint_list = list(pathlib.Path('Checkpoints').glob('**/*.pkl'))
        checkpoint_list = [str(i) for i in checkpoint_list]
        if len(checkpoint_list) > 0:
            checkpoint_list.sort(key=natural_sort_key)
            self.wavenet.load(str(checkpoint_list[-1]))

    def run(self):
        """Train for num_epochs; after each epoch, evaluate on the test set,
        sample once, log to tensorboard and save a checkpoint."""
        self.load_last_checkpoint()
        for epoch in tqdm(range(self.args.num_epochs)):
            for i, (sample,
                    real) in tqdm(enumerate(self.train_data_loader),
                                  total=len(self.train_data_loader)):
                step = i + epoch * len(self.train_data_loader)
                self.wavenet.train(
                    sample.cuda(), real.cuda(), step, True,
                    self.args.num_epochs * len(self.train_data_loader))
            with torch.no_grad():
                # Accumulated loss over the test set (train=False: no update).
                test_loss = 0
                for _, (sample,
                        real) in tqdm(enumerate(self.test_data_loader),
                                      total=len(self.test_data_loader)):
                    test_loss += self.wavenet.train(sample.cuda(),
                                                    real.cuda(),
                                                    train=False)
                test_loss /= len(self.test_data_loader)
                tqdm.write('Testing step Loss: {}'.format(test_loss))
                end_step = (epoch + 1) * len(self.train_data_loader)
                # NOTE(review): indexes the dataset with the *loader* length
                # as the random bound; presumably should be the dataset
                # length — confirm against data.py.
                sample_init, _ = self.train_data_loader.dataset[
                    np.random.randint(len(self.train_data_loader))]
                sampled_image = self.wavenet.sample(end_step, init=sample_init)
                self.test_writer.add_scalar('Testing loss', test_loss,
                                            end_step)
                self.test_writer.add_image('Sampled', sampled_image, end_step)
                self.wavenet.save(end_step)

    def sample(self, num):
        """Generate `num` samples from the newest checkpoint."""
        self.load_last_checkpoint()
        with torch.no_grad():
            for _ in tqdm(range(num)):
                sample_init, _ = self.train_data_loader.dataset[
                    np.random.randint(len(self.train_data_loader))]
                self.wavenet.sample('Sample_{}'.format(int(time.time())),
                                    self.args.temperature, sample_init)
示例#10
0
import tensorflow as tf

# Hyper-parameters for the Wavenet constructed in __main__ below.
audio_frequency = 3000  # target sample rate (Hz) for loading/resampling
receptive_seconds = 0.65  # receptive field expressed in seconds of audio
filter_width = 2
residual_channels = 2
dilation_channels = 2
skip_channels = 2
quantization_channels = 256  # number of amplitude quantization levels
audio_trim_secs = 9  # clips are trimmed to this many seconds

# NOTE(review): module-level I/O — this runs at import time and fails if
# "first.wav" is missing; second positional argument is the sample rate.
data, _ = librosa.load("first.wav", audio_frequency)

if __name__ == '__main__':
    # Build the model and restore the latest checkpoint before use.
    model = Wavenet(audio_frequency, receptive_seconds, filter_width,
                    residual_channels, dilation_channels, skip_channels,
                    quantization_channels)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        try:
            saver.restore(sess, "./training.ckpt")
            print("loaded successfully")
        except Exception as err:
            # Narrowed from a bare except; chain the original failure so the
            # real restore error is not hidden.
            raise ValueError("could not load checkpoint") from err

        # test code
        # X  = np.float32(np.random.randint(1, 256,(1,1,1000,1)))
        # encoded = tf.one_hot(X, depth=256, dtype=tf.float32)
        # encoded = tf.reshape(encoded, [1, 1, -1, 256])
        # print(sess.run(model.create_network(encoded)))
示例#11
0

# works for only http://opihi.cs.uvic.ca/sound/genres.tar.gz
# and only for google cloud
def _get_data(genre, i):
    """Load one GTZAN clip, resample it, and trim it to a fixed length.

    Arguments
    ------------
    genre : str
        Currently unused; the path hard-codes the 'classical' genre.
    i : int
        Track index, zero-padded to five digits in the filename.

    Returns
    ------------
    1-D array of at most audio_frequency * audio_trim_secs samples.
    """
    temp_path = PATH + "classical." + str(i).zfill(5) + '.au'
    # Bug fix: the module is scipy.io.wavfile (not "wavefile"), and WAV data
    # must be read in binary mode, not text mode.
    with file_io.FileIO(temp_path, 'rb') as f:
        rate, samples = scipy.io.wavfile.read(f)
    resampled = librosa.core.resample(samples, rate, audio_frequency)

    return resampled[:audio_frequency * audio_trim_secs]


if __name__ == '__main__':
    model = Wavenet(audio_frequency, receptive_seconds, filter_width,
                    residual_channels, dilation_channels, skip_channels,
                    quantization_channels)

    # Make sure batch=1
    X = tf.placeholder(tf.float32,
                       shape=[1, 1, audio_frequency * audio_trim_secs, 1])
    # define loss and optimizer
    loss = model.loss(X)
    optimizer = tf.train.MomentumOptimizer(LEARNING_RATE,
                                           MOMENTUM).minimize(loss)

    saver = tf.train.Saver()

    config = tf.ConfigProto(log_device_placement=False)
    with tf.Session(config=config) as sess:
        try: