예제 #1
0
    def append(self, snapshot):
        """Add one snapshot to the in-memory storage and/or the output path.

        Args:
            snapshot: Either a numpy array or an object exposing
                ``to_array()``, which is called to obtain the array.

        Raises:
            ValueError: In-memory mode only, when the snapshot's shape differs
                from the shape recorded for the storage.
            IndexError: In-memory mode only, when the reserved storage is full
                and ``max_snapshots`` is set.
        """
        if isinstance(snapshot, np.ndarray):
            values = snapshot
        else:
            values = snapshot.to_array()

        if self.in_memory:
            if self._snapshots is None:
                # First snapshot: its shape defines the storage layout.
                self.snapshot_shape = values.shape
                self._init_snapshot_storage()
            elif values.shape != self.snapshot_shape:
                raise ValueError("This snapshot doesn't match the initialized shape.")

            storage_full = self._current_index >= self._reserved_length
            if storage_full:
                if self.max_snapshots:
                    raise IndexError("Appending exceeded the set maximum snapshot count.")
                self._expand_snapshot_storage()

            self._snapshots[self._current_index, :] = values

        if self.output_path:
            target_name = util.construct_snapshot_name(self.output_path, self._current_index)
            util.save_snapshot(values, target_name)

        # Both counters advance together for every appended snapshot.
        self._current_index += 1
        self.snapshot_count += 1
예제 #2
0
    def save(self, file_name):
        """Write every appended snapshot to a pickle file or a directory.

        Args:
            file_name: A name ending in ``.pkl`` stores all appended
                snapshots in a single numpy pickle. Any other name is passed
                to ``util.construct_snapshot_name`` to build one output file
                name per snapshot.

        Raises:
            IOError: If the snapshots are not kept in memory.
        """
        if not self.in_memory:
            raise IOError("Can't save the storage since it's not stored in memory.")

        if not file_name.endswith(".pkl"):
            # One output file per stored snapshot.
            for index in range(self._current_index):
                target_name = util.construct_snapshot_name(file_name, index)
                util.save_snapshot(self._snapshots[index, :], target_name)
            return

        # Only the rows filled so far are dumped, not the whole reserved buffer.
        filled_rows = self._snapshots[:self._current_index, :]
        filled_rows.dump(file_name)
예제 #3
0
def train(submit_config: dnnlib.SubmitConfig, iteration_count: int,
          eval_interval: int, minibatch_size: int, learning_rate: float,
          ramp_down_perc: float, noise: dict, validation_config: dict,
          train_tfrecords: str, noise2noise: bool):
    """Build the training graph and run the denoiser training loop.

    Args:
        submit_config: Run configuration; ``num_gpus``, ``run_dir`` and
            ``run_id`` are read from it.
        iteration_count: Number of training iterations to run.
        eval_interval: Every this many iterations sample images are saved,
            the validation set is evaluated and timing summaries are written.
        minibatch_size: Forwarded to ``create_dataset``.
        learning_rate: Forwarded, together with ``ramp_down_perc``, to
            ``compute_ramped_down_lrate`` to get the per-iteration rate.
        ramp_down_perc: See ``learning_rate``.
        noise: Keyword arguments for ``dnnlib.util.call_func_by_name``; the
            returned object must provide ``add_train_noise_tf`` and
            ``add_validation_noise_np``.
        validation_config: Keyword arguments for ``ValidationSet.load``.
        train_tfrecords: Forwarded to ``create_dataset`` as the data source.
        noise2noise: When true the loss compares the network output with the
            noisy target, otherwise with the clean target.
    """
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**validation_config)

    # Create a run context (hides low level details, exposes simple API to manage the run)
    # noinspection PyTypeChecker
    ctx = dnnlib.RunContext(submit_config, config)

    # Initialize TensorFlow graph and session using good default settings
    tfutil.init_tf(config.tf_config)

    dataset_iter = create_dataset(train_tfrecords, minibatch_size,
                                  noise_augmenter.add_train_noise_tf)

    # Construct the network using the Network helper class and a function defined in config.net_config
    with tf.device("/gpu:0"):
        net = tflib.Network(**config.net_config)

    # Optionally print layer information
    net.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device("/cpu:0"):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])

        noisy_input, noisy_target, clean_target = dataset_iter.get_next()
        noisy_input_split = tf.split(noisy_input, submit_config.num_gpus)
        noisy_target_split = tf.split(noisy_target, submit_config.num_gpus)
        clean_target_split = tf.split(clean_target, submit_config.num_gpus)

    # Define the loss function using the Optimizer helper class, this will take care of multi GPU
    opt = tflib.Optimizer(learning_rate=lrate_in, **config.optimizer_config)

    for gpu in range(submit_config.num_gpus):
        with tf.device("/gpu:%d" % gpu):
            # GPU 0 uses the network itself; the others work on clones.
            net_gpu = net if gpu == 0 else net.clone()

            denoised = net_gpu.get_output_for(noisy_input_split[gpu])

            if noise2noise:
                meansq_error = tf.reduce_mean(
                    tf.square(noisy_target_split[gpu] - denoised))
            else:
                meansq_error = tf.reduce_mean(
                    tf.square(clean_target_split[gpu] - denoised))
            # Create an autosummary that will average over all GPUs
            with tf.control_dependencies([autosummary("Loss", meansq_error)]):
                opt.register_gradients(meansq_error, net_gpu.trainables)

    train_step = opt.apply_updates()

    # Create a log file for Tensorboard
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    summary_log.add_graph(tf.get_default_graph())

    print('Training...')
    time_maintenance = ctx.get_time_since_last_update()
    ctx.update(loss='run %d' % submit_config.run_id,
               cur_epoch=0,
               max_epoch=iteration_count)

    # ***********************************
    # The actual training loop
    for i in range(iteration_count):
        # Whether to stop the training or not should be asked from the context
        if ctx.should_stop():
            break

        # Dump training status
        if i % eval_interval == 0:
            time_train = ctx.get_time_since_last_update()
            time_total = ctx.get_time_since_start()

            # Evaluate 'x' to draw a batch of inputs
            [source_mb, target_mb] = tfutil.run([noisy_input, clean_target])
            denoised = net.run(source_mb)
            save_image(submit_config, denoised[0],
                       "img_{0}_y_pred.png".format(i))
            save_image(submit_config, target_mb[0], "img_{0}_y.png".format(i))
            save_image(submit_config, source_mb[0],
                       "img_{0}_x_aug.png".format(i))

            validation_set.evaluate(net, i,
                                    noise_augmenter.add_validation_noise_np)

            # The ETA gets its own summary tag. Recording it under
            # 'Timing/total_sec' would mix it with the elapsed time logged
            # under that same tag.
            print(
                'iter %-10d time %-12s eta %-12s sec/eval %-7.1f sec/iter %-7.2f maintenance %-6.1f'
                % (autosummary('Timing/iter', i),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', time_total)),
                   dnnlib.util.format_time(
                       autosummary('Timing/eta_sec',
                                   (time_train / eval_interval) *
                                   (iteration_count - i))),
                   autosummary('Timing/sec_per_eval', time_train),
                   autosummary('Timing/sec_per_iter',
                               time_train / eval_interval),
                   autosummary('Timing/maintenance_sec', time_maintenance)))

            dnnlib.tflib.autosummary.save_summaries(summary_log, i)
            ctx.update(loss='run %d' % submit_config.run_id,
                       cur_epoch=i,
                       max_epoch=iteration_count)
            time_maintenance = ctx.get_last_update_interval() - time_train

        # Training epoch
        lrate = compute_ramped_down_lrate(i, iteration_count, ramp_down_perc,
                                          learning_rate)
        tfutil.run([train_step], {lrate_in: lrate})

    # End of training
    print("Elapsed time: {0}".format(
        util.format_time(ctx.get_time_since_start())))
    save_snapshot(submit_config, net, 'final')

    # Summary log and context should be closed at the end
    summary_log.close()
    ctx.close()
예제 #4
0
파일: command.py 프로젝트: skyornig/ekt
def train(model, args):
    """Train ``model`` one user sequence at a time and save snapshots.

    Args:
        model: Network to train. When its class name starts with ``'DK'`` the
            per-item input is a knowledge vector built from
            ``data/know_list.txt`` and ``data/id_know.txt``; otherwise it is
            the ``content`` of the topic returned by ``get_topics``.
        args: Options object. Reads ``dataset``, ``random_level``,
            ``split_method``, ``frac``, ``workspace``, ``epochs``, ``loss``,
            ``save_every`` and ``print_every``.
    """
    logging.info('args: %s' % str(args))
    logging.info('model: %s, setup: %s' %
                 (type(model).__name__, str(model.args)))
    logging.info('loading dataset')
    data = get_dataset(args.dataset)
    data.random_level = args.random_level

    if args.split_method == 'user':
        data, _ = data.split_user(args.frac)
    elif args.split_method == 'future':
        data, _ = data.split_future(args.frac)
    elif args.split_method == 'old':
        data, _, _, _ = data.split()

    data = data.get_seq()

    # The model class does not change during training, so decide the input
    # encoding once instead of on every sequence item.
    use_knowledge = type(model).__name__.startswith('DK')

    if use_knowledge:
        topic_dic = {}
        kcat = Categorical(one_hot=True)
        # Context managers make sure both files are closed after reading.
        with open('data/know_list.txt') as know_list_file:
            kcat.load_dict(know_list_file.read().split('\n'))
        with open('data/id_know.txt') as id_know_file:
            for line in id_know_file:
                uuid, know = line.strip().split(' ')
                know = know.split(',')
                topic_dic[uuid] = \
                    torch.LongTensor(kcat.apply(None, know)) \
                    .max(0)[0] \
                    .type(torch.LongTensor)
        # Fallback input for topics that have no entry in topic_dic.
        zero = [0] * len(kcat.apply(None, '<NULL>'))
    else:
        topics = get_topics(args.dataset, model.words)

    optimizer = torch.optim.Adam(model.parameters())

    start_epoch = load_last_snapshot(model, args.workspace)
    if use_cuda:
        model.cuda()

    for epoch in range(start_epoch, args.epochs):
        logging.info(('epoch {}:'.format(epoch)))
        then = time.time()

        total_loss = 0
        total_mae = 0
        total_acc = 0
        total_seq_cnt = 0

        users = list(data)
        random.shuffle(users)
        seq_cnt = len(users)

        MSE = torch.nn.MSELoss()
        MAE = torch.nn.L1Loss()

        for user in users:
            total_seq_cnt += 1

            seq = data[user]
            length = len(seq)

            optimizer.zero_grad()

            loss = 0
            mae = 0
            acc = 0

            h = None

            for i, item in enumerate(seq):
                if use_knowledge:
                    if item.topic in topic_dic:
                        x = topic_dic[item.topic]
                    else:
                        x = zero
                else:
                    x = topics.get(item.topic).content
                x = Variable(torch.LongTensor(x))
                # print(x.size())
                score = Variable(torch.FloatTensor([round(item.score)]))
                t = Variable(torch.FloatTensor([item.time]))
                s, h = model(x, score, t, h)
                if args.loss == 'cross_entropy':
                    loss += F.binary_cross_entropy_with_logits(
                        s, score.view_as(s))
                    m = MAE(F.sigmoid(s), score).data[0]
                else:
                    loss += MSE(s, score)
                    m = MAE(s, score).data[0]
                mae += m
                # An item counts as correct when its absolute error is below 0.5.
                acc += m < 0.5

            loss /= length
            mae /= length
            # acc is an integer count; force true division so the ratio is not
            # truncated under Python 2 integer division.
            acc = float(acc) / length

            total_loss += loss.data[0]
            total_mae += mae
            total_acc += acc

            loss.backward()
            optimizer.step()

            if total_seq_cnt % args.save_every == 0:
                save_snapshot(model, args.workspace,
                              '%d.%d' % (epoch, total_seq_cnt))

            # Report only every print_every sequences and after the last one.
            if total_seq_cnt % args.print_every != 0 and \
                    total_seq_cnt != seq_cnt:
                continue

            now = time.time()
            duration = (now - then) / 60

            logging.info(
                '[%d:%d/%d] (%.2f seqs/min) '
                'loss %.6f, mae %.6f, acc %.6f' %
                (epoch, total_seq_cnt, seq_cnt,
                 ((total_seq_cnt - 1) % args.print_every + 1) / duration,
                 total_loss / total_seq_cnt, total_mae / total_seq_cnt,
                 total_acc / total_seq_cnt))
            then = now

        save_snapshot(model, args.workspace, epoch + 1)
예제 #5
0
파일: command.py 프로젝트: skyornig/ekt
def trainn(model, args):
    """Train a knowledge- and/or text-driven model on per-user sequences.

    Args:
        model: Network to train. Its class name must start with ``'DK'``
            (knowledge input), ``'RA'`` (text input) or ``'EK'`` (both).
            ``model.args['knows']`` gives the knowledge dictionary file and
            ``model.words`` is forwarded to ``get_topics``.
        args: Options object. Reads ``dataset``, ``random_level``,
            ``split_method``, ``frac``, ``input_knowledge``, ``input_text``,
            ``workspace``, ``epochs``, ``loss``, ``save_every`` and
            ``print_every``.

    Raises:
        ValueError: If the model's class name has none of the supported
            prefixes.
    """
    logging.info('model: %s, setup: %s' %
                 (type(model).__name__, str(model.args)))
    logging.info('loading dataset')
    data = get_dataset(args.dataset)
    data.random_level = args.random_level

    if args.split_method == 'user':
        data, _ = data.split_user(args.frac)
    elif args.split_method == 'future':
        data, _ = data.split_future(args.frac)
    elif args.split_method == 'old':
        data, _, _, _ = data.split()

    data = data.get_seq()

    if args.input_knowledge:
        logging.info('loading knowledge concepts')
        topic_dic = {}
        kcat = Categorical(one_hot=True)
        # Context managers make sure both files are closed after reading.
        with open(model.args['knows']) as knows_file:
            kcat.load_dict(knows_file.read().split('\n'))
        # The path gets its own name so the loop variable cannot clobber it.
        know_path = 'data/id_firstknow.txt' \
            if 'first' in model.args['knows'] else 'data/id_know.txt'
        with open(know_path) as know_file:
            for line in know_file:
                uuid, know = line.strip().split(' ')
                know = know.split(',')
                topic_dic[uuid] = torch.LongTensor(kcat.apply(None,
                                                              know)).max(0)[0]
        # Fallback input for topics that have no entry in topic_dic.
        zero = [0] * len(kcat.apply(None, '<NULL>'))

    if args.input_text:
        logging.info('loading exercise texts')
        topics = get_topics(args.dataset, model.words)

    optimizer = torch.optim.Adam(model.parameters())

    start_epoch = load_last_snapshot(model, args.workspace)
    if use_cuda:
        model.cuda()

    # The class name selects the forward-call signature; look it up once.
    model_name = type(model).__name__

    for epoch in range(start_epoch, args.epochs):
        logging.info('epoch {}:'.format(epoch))
        then = time.time()

        total_loss = 0
        total_mae = 0
        total_acc = 0
        total_seq_cnt = 0

        users = list(data)
        random.shuffle(users)
        seq_cnt = len(users)

        MSE = torch.nn.MSELoss()
        MAE = torch.nn.L1Loss()

        for user in users:
            total_seq_cnt += 1

            seq = data[user]
            seq_length = len(seq)

            optimizer.zero_grad()

            loss = 0
            mae = 0
            acc = 0

            h = None

            for i, item in enumerate(seq):
                if args.input_knowledge:
                    if item.topic in topic_dic:
                        knowledge = topic_dic[item.topic]
                    else:
                        knowledge = zero
                    knowledge = Variable(torch.LongTensor(knowledge))

                if args.input_text:
                    text = topics.get(item.topic).content
                    text = Variable(torch.LongTensor(text))
                score = Variable(torch.FloatTensor([item.score]))
                item_time = Variable(torch.FloatTensor([item.time]))

                if model_name.startswith('DK'):
                    s, h = model(knowledge, score, item_time, h)
                elif model_name.startswith('RA'):
                    s, h = model(text, score, item_time, h)
                elif model_name.startswith('EK'):
                    s, h = model(text, knowledge, score, item_time, h)
                else:
                    # Fail with a clear message rather than an unbound-variable
                    # error on ``s`` below.
                    raise ValueError(
                        'unsupported model type: %s' % model_name)

                s = s[0]

                if args.loss == 'cross_entropy':
                    loss += F.binary_cross_entropy_with_logits(
                        s, score.view_as(s))
                    m = MAE(F.sigmoid(s), score).data[0]
                else:
                    loss += MSE(s, score)
                    m = MAE(s, score).data[0]
                mae += m
                # An item counts as correct when its absolute error is below 0.5.
                acc += m < 0.5

            loss /= seq_length
            mae /= seq_length
            acc = float(acc) / seq_length

            total_loss += loss.data[0]
            total_mae += mae
            total_acc += acc

            loss.backward()
            optimizer.step()

            if total_seq_cnt % args.save_every == 0:
                save_snapshot(model, args.workspace,
                              '%d.%d' % (epoch, total_seq_cnt))

            # Report only every print_every sequences and after the last one.
            if total_seq_cnt % args.print_every != 0 and total_seq_cnt != seq_cnt:
                continue

            now = time.time()
            duration = (now - then) / 60

            logging.info(
                '[%d:%d/%d] (%.2f seqs/min) loss %.6f, mae %.6f, acc %.6f' %
                (epoch, total_seq_cnt, seq_cnt,
                 ((total_seq_cnt - 1) % args.print_every + 1) / duration,
                 total_loss / total_seq_cnt, total_mae / total_seq_cnt,
                 total_acc / total_seq_cnt))
            then = now

        save_snapshot(model, args.workspace, epoch + 1)
예제 #6
0
def train(
        submit_config: submit.SubmitConfig,
        iteration_count: int,
        eval_interval: int,
        minibatch_size: int,
        learning_rate: float,
        ramp_down_perc: float,
        noise: dict,
        tf_config: dict,
        net_config: dict,
        optimizer_config: dict,
        validation_config: dict,
        train_tfrecords: str):
    """Set up the TensorFlow training graph and run the training loop.

    The loss is the mean squared difference between the network output for
    the noisy input and the noisy target. Every ``eval_interval`` iterations
    three sample images are saved, the validation set is evaluated and timing
    summaries are written. A final snapshot is saved when the loop ends.

    Args:
        submit_config: Run configuration; ``num_gpus`` and ``results_dir``
            are read from it.
        iteration_count: Number of training iterations.
        eval_interval: Iterations between two status dumps.
        minibatch_size: Forwarded to ``create_dataset``.
        learning_rate: Forwarded to ``compute_ramped_down_lrate``.
        ramp_down_perc: Forwarded to ``compute_ramped_down_lrate``.
        noise: Object with ``func`` and ``func_kwargs`` attributes used to
            build the noise augmenter.
        tf_config: Forwarded to ``tfutil.init_tf``.
        net_config: Keyword arguments for ``Network``.
        optimizer_config: Extra keyword arguments for ``Optimizer``.
        validation_config: Keyword arguments for ``ValidationSet.load``.
        train_tfrecords: Forwarded to ``create_dataset`` as the data source.
    """
    # Build the augmenter from the configured callable and its keyword arguments.
    augmenter = noise.func(**noise.func_kwargs)

    # The validation set loads its data according to validation_config.
    val_set = ValidationSet(submit_config)
    val_set.load(**validation_config)

    # The run context answers stop requests and tracks timing.
    run_ctx = run_context.RunContext(submit_config)

    # Bring up the TensorFlow graph and session.
    tfutil.init_tf(tf_config)

    # Iterator over training minibatches; noise is applied through the
    # augmenter's TensorFlow function.
    batch_iter = create_dataset(train_tfrecords, minibatch_size, augmenter.add_train_noise_tf)

    # The reference network is placed on the first GPU.
    with tf.device("/gpu:0"):
        base_net = Network(**net_config)

    base_net.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device("/cpu:0"):
        # Fed on every step with the ramped-down learning rate.
        lr_placeholder = tf.placeholder(tf.float32, shape=[], name='lrate_in')

        # One get_next() call yields the three tensors of a minibatch.
        noisy_input, noisy_target, clean_target = batch_iter.get_next()
        # One shard per GPU.
        input_shards = tf.split(noisy_input, submit_config.num_gpus)
        target_shards = tf.split(noisy_target, submit_config.num_gpus)

    # ----------------------------------------------------------------------
    # Optimizer: gradients are registered per GPU and applied together.
    optimizer = Optimizer(learning_rate=lr_placeholder, **optimizer_config)

    for gpu_id in range(submit_config.num_gpus):
        with tf.device("/gpu:%d" % gpu_id):
            # GPU 0 uses the reference network; every other GPU gets a clone.
            replica = base_net.clone() if gpu_id != 0 else base_net

            prediction = replica.get_output_for(input_shards[gpu_id])

            # Mean squared error against the noisy target.
            loss = tf.reduce_mean(tf.square(target_shards[gpu_id] - prediction))
            # The autosummary averages the loss over all GPUs.
            with tf.control_dependencies([autosummary("Loss", loss)]):
                optimizer.register_gradients(loss, replica.trainables)

    update_op = optimizer.apply_updates()

    # Tensorboard log file, including the graph definition.
    tb_writer = tf._api.v1.summary.FileWriter(submit_config.results_dir)
    tb_writer.add_graph(tf.get_default_graph())

    # ----------------------------------------------------------------------
    # Training with periodic evaluation.
    print('Training...')
    maintenance_secs = run_ctx.get_time_since_last_update()
    run_ctx.update()  # TODO: why parameterized in reference?

    for step in range(iteration_count):
        # The context decides whether training has to stop early.
        if run_ctx.should_stop():
            break

        if step % eval_interval == 0:
            train_secs = run_ctx.get_time_since_last_update()
            total_secs = run_ctx.get_time_since_start()

            # Draw one minibatch as numpy arrays and run it through the
            # network without a training step.
            noisy_batch, clean_batch = tfutil.run([noisy_input, clean_target])
            predicted_batch = base_net.run(noisy_batch)

            # Save the first image of each batch.
            sample_outputs = (
                (predicted_batch, "img_{0}_y_pred.png"),
                (clean_batch, "img_{0}_y.png"),
                (noisy_batch, "img_{0}_x_aug.png"),
            )
            for batch, name_pattern in sample_outputs:
                util.save_image(submit_config, batch[0], name_pattern.format(step))

            val_set.evaluate(base_net, step, augmenter.add_validation_noise_np)

            # Record the timing summaries in the same order they are printed.
            iter_value = autosummary('Timing/iter', step)
            total_text = dnnlib.util.format_time(autosummary('Timing/total_sec', total_secs))
            eval_value = autosummary('Timing/sec_per_eval', train_secs)
            per_iter_value = autosummary('Timing/sec_per_iter', train_secs / eval_interval)
            maintenance_value = autosummary('Timing/maintenance_sec', maintenance_secs)
            print('iter %-10d time %-12s sec/eval %-7.1f sec/iter %-7.2f maintenance %-6.1f' % (
                iter_value, total_text, eval_value, per_iter_value, maintenance_value))

            dnnlib.tflib.autosummary.save_summaries(tb_writer, step)
            run_ctx.update()
            maintenance_secs = run_ctx.get_last_update_interval() - train_secs

        # Feed this step's learning rate and run one optimizer update.
        current_lr = compute_ramped_down_lrate(step, iteration_count, ramp_down_perc, learning_rate)
        tfutil.run([update_op], {lr_placeholder: current_lr})

    elapsed_text = dutil.format_time(run_ctx.get_time_since_start())
    print("Elapsed time: {0}".format(elapsed_text))
    util.save_snapshot(submit_config, base_net, 'final')

    # Close the summary writer first, then the run context.
    tb_writer.close()
    run_ctx.close()
예제 #7
0
def main():
    """Parse options, prepare result directories and run the training loop.

    The loop walks the training data loader repeatedly until the global step
    reaches ``max_iter``. Before each step it saves snapshots, runs testers
    and updates the learning rate at the configured iterations. When a step
    returns no loss dictionary, the step counter is not advanced and those
    checks are skipped once for the next batch.
    """
    torch.multiprocessing.set_sharing_strategy('file_system')
    print('[RUN] parse arguments')
    args, framework, optimizer, data_loader_dict, tester_dict = option.parse_options()

    print('[RUN] create result directories')
    result_dirs = util.create_result_dir(args.result_dir, ['src', 'log', 'snapshot', 'test'])
    util.copy_file(args.bash_file, args.result_dir)
    util.copy_dir('./src', result_dirs['src'])

    print('[RUN] create loggers')
    writer = SummaryWriter(os.path.join(result_dirs['log'], 'train'))

    print('[OPTIMIZER] learning rate:', optimizer.param_groups[0]['lr'])
    n_batches = len(data_loader_dict['train'])
    step = args.training_args['init_iter']

    print('')
    skip_checks = False
    while True:
        tick = time.time()
        for batch in data_loader_dict['train']:
            # Time spent waiting for this batch.
            batch_time = time.time() - tick

            run_checks = not skip_checks
            skip_checks = False
            if run_checks:
                if step in args.snapshot_iters:
                    snapshot_dir = os.path.join(result_dirs['snapshot'], '%07d' % step)
                    util.save_snapshot(framework.network, optimizer, snapshot_dir)

                if step in args.test_iters:
                    test_dir = os.path.join(result_dirs['test'], '%07d' % step)
                    util.run_testers(tester_dict, framework, data_loader_dict['test'], test_dir)

                if args.training_args['max_iter'] <= step:
                    break

                if step in args.training_args['lr_decay_schd']:
                    util.update_learning_rate(optimizer, args.training_args['lr_decay_schd'][step])

            losses, train_time = train_network_one_step(args, framework, optimizer, batch, step)

            if losses is None:
                # No loss for this batch: drop it and skip the per-step
                # checks once on the next batch.
                skip_checks = True
                batch.clear()
                del batch
                tick = time.time()
                continue

            if step % args.training_args['print_intv'] == 0:
                iter_str = '[TRAINING] %d/%d:' % (step, args.training_args['max_iter'])
                info_str = 'n_batches: %d, batch_time: %0.3f, train_time: %0.3f' % \
                           (n_batches, batch_time, train_time)
                train_str = util.cvt_dict2str(losses)
                print('\n- '.join([iter_str, info_str, train_str]) + '\n')

                for tag, value in losses.items():
                    writer.add_scalar(tag, value, step)

            # Release per-step data before fetching the next batch.
            losses.clear()
            batch.clear()
            del losses, batch
            step += 1

            tick = time.time()
        if args.training_args['max_iter'] <= step:
            break
    writer.close()
예제 #8
0
파일: train.py 프로젝트: nellyyi/frc-new
def train(submit_config: dnnlib.SubmitConfig, iteration_count: int,
          eval_interval: int, minibatch_size: int, learning_rate: float,
          ramp_down_perc: float, noise: dict, validation_config: dict,
          train_tfrecords: str, noise2noise: bool):
    """Build the training graph with an FRC-based loss and run the loop.

    Args:
        submit_config: Run configuration; ``num_gpus``, ``run_dir`` and
            ``run_id`` are read from it.
        iteration_count: Number of training iterations to run.
        eval_interval: Every this many iterations sample images are saved,
            the validation set is evaluated, timing summaries are written
            and a snapshot named after the iteration is saved.
        minibatch_size: Forwarded to ``create_dataset``.
        learning_rate: Forwarded, together with ``ramp_down_perc``, to
            ``compute_ramped_down_lrate`` to get the per-iteration rate.
        ramp_down_perc: See ``learning_rate``.
        noise: Keyword arguments for ``dnnlib.util.call_func_by_name``; the
            returned object must provide ``add_train_noise_tf`` and
            ``add_validation_noise_np``.
        validation_config: Keyword arguments for ``ValidationSet.load``.
        train_tfrecords: Forwarded to ``create_dataset`` as the data source.
        noise2noise: When true the loss is the difference of two
            ``fourier_ring_correlation`` terms; otherwise it is the mean
            squared error between the clean target and the network output.
    """
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**validation_config)

    # Create a run context (hides low level details, exposes simple API to manage the run)
    ctx = dnnlib.RunContext(submit_config, config)

    # Initialize TensorFlow graph and session using good default settings
    tfutil.init_tf(config.tf_config)

    dataset_iter = create_dataset(train_tfrecords, minibatch_size,
                                  noise_augmenter.add_train_noise_tf)
    # Construct the network using the Network helper class and a function defined in config.net_config
    with tf.device("/gpu:0"):
        net = tflib.Network(**config.net_config)

    # Optionally print layer information
    net.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device("/cpu:0"):
        lrate_in = tf.compat.v1.placeholder(tf.float32,
                                            name='lrate_in',
                                            shape=[])

        #print("DEBUG train:", "dataset iter got called")
        noisy_input, noisy_target, clean_target = dataset_iter.get_next()
        noisy_input_split = tf.split(noisy_input, submit_config.num_gpus)
        noisy_target_split = tf.split(noisy_target, submit_config.num_gpus)
        print(len(noisy_input_split), noisy_input_split)
        clean_target_split = tf.split(clean_target, submit_config.num_gpus)
        # Split [?, 3, 256, 256] across num_gpus over axis 0 (i.e. the batch)

    # Define the loss function using the Optimizer helper class, this will take care of multi GPU
    opt = tflib.Optimizer(learning_rate=lrate_in, **config.optimizer_config)
    radii = np.arange(128).reshape(128, 1)  #image size 256, binning = 3
    radial_masks = np.apply_along_axis(radial_mask, 1, radii, 128, 128,
                                       np.arange(0, 256), np.arange(0, 256),
                                       20)
    print("RN SHAPE!!!!!!!!!!:", radial_masks.shape)
    radial_masks = np.expand_dims(radial_masks, 1)  # (128, 1, 256, 256)
    #radial_masks = np.squeeze(np.stack((radial_masks,) * 3, -1)) # 43, 3, 256, 256
    #radial_masks = radial_masks.transpose([0, 3, 1, 2])
    radial_masks = radial_masks.astype(np.complex64)
    radial_masks = tf.expand_dims(radial_masks, 1)

    rn = tf.compat.v1.placeholder_with_default(radial_masks,
                                               [128, None, 1, 256, 256])
    rn_split = tf.split(rn, submit_config.num_gpus, axis=1)
    freq_nyq = int(np.floor(int(256) / 2.0))

    spatial_freq = radii.astype(np.float32) / freq_nyq
    spatial_freq = spatial_freq / max(spatial_freq)

    for gpu in range(submit_config.num_gpus):
        with tf.device("/gpu:%d" % gpu):
            # GPU 0 uses the network itself; the others work on clones.
            net_gpu = net if gpu == 0 else net.clone()

            denoised_1 = net_gpu.get_output_for(noisy_input_split[gpu])
            denoised_2 = net_gpu.get_output_for(noisy_target_split[gpu])
            print(noisy_input_split[gpu].get_shape(),
                  rn_split[gpu].get_shape())
            if noise2noise:
                meansq_error = fourier_ring_correlation(
                    noisy_target_split[gpu], denoised_1, rn_split[gpu],
                    spatial_freq) - fourier_ring_correlation(
                        noisy_target_split[gpu] - denoised_2,
                        noisy_input_split[gpu] - denoised_1, rn_split[gpu],
                        spatial_freq)
            else:
                # Use denoised_1 (the output for the noisy input); the plain
                # name ``denoised`` is not defined at this point.
                meansq_error = tf.reduce_mean(
                    tf.square(clean_target_split[gpu] - denoised_1))
            # Create an autosummary that will average over all GPUs
            #tf.summary.histogram(name, var)
            with tf.control_dependencies([autosummary("Loss", meansq_error)]):
                opt.register_gradients(meansq_error, net_gpu.trainables)

    train_step = opt.apply_updates()

    # Create a log file for Tensorboard
    summary_log = tf.compat.v1.summary.FileWriter(submit_config.run_dir)
    summary_log.add_graph(tf.compat.v1.get_default_graph())

    print('Training...')
    time_maintenance = ctx.get_time_since_last_update()
    ctx.update(loss='run %d' % submit_config.run_id,
               cur_epoch=0,
               max_epoch=iteration_count)

    # The actual training loop
    for i in range(iteration_count):
        # Whether to stop the training or not should be asked from the context
        if ctx.should_stop():
            break
        # Dump training status
        if i % eval_interval == 0:

            time_train = ctx.get_time_since_last_update()
            time_total = ctx.get_time_since_start()
            print("DEBUG TRAIN!", noisy_input.dtype, noisy_input[0][0].dtype)
            # Evaluate 'x' to draw a batch of inputs
            [source_mb, target_mb] = tfutil.run([noisy_input, clean_target])
            denoised = net.run(source_mb)
            save_image(submit_config, denoised[0],
                       "img_{0}_y_pred.tif".format(i))
            save_image(submit_config, target_mb[0], "img_{0}_y.tif".format(i))
            save_image(submit_config, source_mb[0],
                       "img_{0}_x_aug.tif".format(i))

            validation_set.evaluate(net, i,
                                    noise_augmenter.add_validation_noise_np)

            print(
                'iter %-10d time %-12s sec/eval %-7.1f sec/iter %-7.2f maintenance %-6.1f'
                % (autosummary('Timing/iter', i),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', time_total)),
                   autosummary('Timing/sec_per_eval', time_train),
                   autosummary('Timing/sec_per_iter',
                               time_train / eval_interval),
                   autosummary('Timing/maintenance_sec', time_maintenance)))

            dnnlib.tflib.autosummary.save_summaries(summary_log, i)
            ctx.update(loss='run %d' % submit_config.run_id,
                       cur_epoch=i,
                       max_epoch=iteration_count)
            time_maintenance = ctx.get_last_update_interval() - time_train

            # Keep an intermediate snapshot for every evaluation point.
            save_snapshot(submit_config, net, str(i))
        lrate = compute_ramped_down_lrate(i, iteration_count, ramp_down_perc,
                                          learning_rate)
        tfutil.run([train_step], {lrate_in: lrate})

    print("Elapsed time: {0}".format(
        util.format_time(ctx.get_time_since_start())))
    save_snapshot(submit_config, net, 'final')

    # Summary log and context should be closed at the end
    summary_log.close()
    ctx.close()