Example #1
import torch
from adamod import AdaMod  # assumed source: the 'adamod' package


def setup_optimizer(params, config):
    """Set up optimizer for dent adaptation.
    Dent needs an optimizer for test-time entropy minimization.
    In principle, dent could make use of any gradient optimizer.
    In practice, we advise choosing AdaMod.
    For optimization settings, we advise using the settings from the end of
    training, if known, or starting with a low learning rate (like 0.001) if not.
    For best results, try tuning the learning rate and batch size.
    """
    if config.METHOD == 'AdaMod':
        return AdaMod(params,
                lr=config.LR,
                betas=(config.BETA, 0.999),
                beta3=config.BETA3,
                weight_decay=config.WD)
    elif config.METHOD == 'Adam':
        return torch.optim.Adam(params,
                lr=config.LR,
                betas=(config.BETA, 0.999),
                weight_decay=config.WD)
    elif config.METHOD == 'SGD':
        return torch.optim.SGD(params,
                lr=config.LR,
                momentum=config.MOMENTUM,
                dampening=config.DAMPENING,
                weight_decay=config.WD,
                nesterov=config.NESTEROV)
    else:
        raise NotImplementedError
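
A minimal usage sketch for setup_optimizer; the namespace-style config and its values are hypothetical stand-ins for the attribute names the function reads:

from types import SimpleNamespace
import torch.nn as nn

model = nn.Linear(10, 2)
# Low learning rate per the docstring's advice when end-of-training settings
# are unknown; beta3=0.999 is AdaMod's commonly cited default.
config = SimpleNamespace(METHOD='AdaMod', LR=0.001, BETA=0.9,
                         BETA3=0.999, WD=0.0)
optimizer = setup_optimizer(model.parameters(), config)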
Example #2
import torch.optim as optim
from adamod import AdaMod  # assumed source: the 'adamod' package


def create_optimizer(args, model_params):
    if args.optim == 'sgd':
        return optim.SGD(model_params,
                         args.lr,
                         momentum=args.momentum,
                         weight_decay=args.weight_decay)
    elif args.optim == 'adam':
        # Note: the 'adam' flag actually selects AdamW, i.e. Adam with
        # decoupled weight decay.
        return optim.AdamW(model_params,
                           args.lr,
                           betas=(args.beta1, args.beta2),
                           weight_decay=args.weight_decay)
    elif args.optim == 'adamod':
        return AdaMod(model_params,
                      args.lr,
                      betas=(args.beta1, args.beta2),
                      beta3=args.beta3,
                      weight_decay=args.weight_decay)
    else:
        # Fail loudly instead of silently returning None.
        raise ValueError("Unknown optimizer: {}".format(args.optim))
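
A usage sketch for create_optimizer with hypothetical hyperparameter values packed into an argparse-style namespace:

import torch.nn as nn
from argparse import Namespace

model = nn.Linear(10, 2)
args = Namespace(optim='adamod', lr=1e-3, beta1=0.9, beta2=0.999,
                 beta3=0.999, weight_decay=1e-4, momentum=0.9)
optimizer = create_optimizer(args, model.parameters())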
Example #3
class TFProcess:
    def __init__(self, cfg):
        self.cfg = cfg
        self.net = Net()
        self.root_dir = os.path.join(self.cfg["training"]["path"],
                                     self.cfg["name"])

        # Network structure
        self.RESIDUAL_FILTERS = self.cfg["model"]["filters"]
        self.RESIDUAL_BLOCKS = self.cfg["model"]["residual_blocks"]
        self.SE_ratio = self.cfg["model"]["se_ratio"]
        self.policy_channels = self.cfg["model"].get("policy_channels", 32)
        precision = self.cfg["training"].get("precision", "single")
        loss_scale = self.cfg["training"].get("loss_scale", 128)

        if precision == "single":
            self.model_dtype = tf.float32
        elif precision == "half":
            self.model_dtype = tf.float16
        else:
            raise ValueError("Unknown precision: {}".format(precision))

        # Scale the loss to prevent gradient underflow
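        # With float16, gradient magnitudes below roughly 6e-8 underflow to
        # zero; multiplying the loss by loss_scale (e.g. 128) shifts them back
        # into range, and apply_grads divides them out again before the update.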
        self.loss_scale = 1 if self.model_dtype == tf.float32 else loss_scale

        policy_head = self.cfg["model"].get("policy", "convolution")
        value_head = self.cfg["model"].get("value", "wdl")

        self.POLICY_HEAD = None
        self.VALUE_HEAD = None

        if policy_head == "classical":
            self.POLICY_HEAD = pb.NetworkFormat.POLICY_CLASSICAL
        elif policy_head == "convolution":
            self.POLICY_HEAD = pb.NetworkFormat.POLICY_CONVOLUTION
        else:
            raise ValueError(
                "Unknown policy head format: {}".format(policy_head))

        self.net.set_policyformat(self.POLICY_HEAD)

        if value_head == "classical":
            self.VALUE_HEAD = pb.NetworkFormat.VALUE_CLASSICAL
            self.wdl = False
        elif value_head == "wdl":
            self.VALUE_HEAD = pb.NetworkFormat.VALUE_WDL
            self.wdl = True
        else:
            raise ValueError(
                "Unknown value head format: {}".format(value_head))

        self.net.set_valueformat(self.VALUE_HEAD)

        self.swa_enabled = self.cfg["training"].get("swa", False)

        # Limit momentum of SWA exponential average to 1 - 1/(swa_max_n + 1)
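        # e.g. swa_max_n = 10 caps the old-average weight at 10/11, so the
        # newest network always contributes at least 1/11 (see update_swa_v2).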
        self.swa_max_n = self.cfg["training"].get("swa_max_n", 0)

        self.renorm_enabled = self.cfg["training"].get("renorm", False)
        self.renorm_max_r = self.cfg["training"].get("renorm_max_r", 1)
        self.renorm_max_d = self.cfg["training"].get("renorm_max_d", 0)
        self.renorm_momentum = self.cfg["training"].get(
            "renorm_momentum", 0.99)

        if self.cfg["gpu"] == "all":
            self.strategy = tf.distribute.MirroredStrategy()
            tf.distribute.experimental_set_strategy(self.strategy)
        else:
            gpus = tf.config.experimental.list_physical_devices("GPU")
            tf.config.experimental.set_visible_devices(gpus[self.cfg["gpu"]],
                                                       "GPU")
            tf.config.experimental.set_memory_growth(gpus[self.cfg["gpu"]],
                                                     True)
            self.strategy = None
        if self.model_dtype == tf.float16:
            tf.keras.mixed_precision.experimental.set_policy("mixed_float16")

        self.global_step = tf.Variable(0,
                                       name="global_step",
                                       trainable=False,
                                       dtype=tf.int64)

    def init_v2(self, train_dataset, test_dataset, validation_dataset=None):
        if self.strategy is not None:
            self.train_dataset = self.strategy.experimental_distribute_dataset(
                train_dataset)
            self.test_dataset = self.strategy.experimental_distribute_dataset(
                test_dataset)
            self.validation_dataset = self.strategy.experimental_distribute_dataset(
                validation_dataset)
        else:
            self.train_dataset = train_dataset
            self.test_dataset = test_dataset
            self.validation_dataset = validation_dataset
        self.train_iter = iter(self.train_dataset)
        self.test_iter = iter(self.test_dataset)
        self.init_net_v2()

    def init_net_v2(self):
        self.l2reg = tf.keras.regularizers.l2(l=0.5 * (0.0001))
        input_var = tf.keras.Input(shape=(112, 8 * 8))
        x_planes = tf.keras.layers.Reshape([112, 8, 8])(input_var)
        self.model = tf.keras.Model(inputs=input_var,
                                    outputs=self.construct_net_v2(x_planes))
        # swa_count is initialized regardless, to keep the checkpoint code simple.
        self.swa_count = tf.Variable(0.0, name="swa_count", trainable=False)
        self.swa_weights = None
        if self.swa_enabled:
            # Variables that accumulate the SWA running average of the weights.
            self.swa_weights = [
                tf.Variable(w, trainable=False) for w in self.model.weights
            ]

        self.active_lr = 0.01
        self.optimizer = AdaMod(learning_rate=lambda: self.active_lr)
        self.orig_optimizer = self.optimizer
        if self.loss_scale != 1:
            self.optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
                self.optimizer, self.loss_scale)

        def correct_policy(target, output):
            output = tf.cast(output, tf.float32)
            # Calculate loss on policy head
            if self.cfg["training"].get("mask_legal_moves"):
                # extract mask for legal moves from target policy
                move_is_legal = tf.greater_equal(target, 0)
                # replace logits of illegal moves with large negative value (so that it doesn't affect policy of legal moves) without gradient
                illegal_filler = tf.zeros_like(output) - 1.0e10
                output = tf.where(move_is_legal, output, illegal_filler)
            # y_ still has -1 on illegal moves, flush them to 0
            target = tf.nn.relu(target)
            return target, output

        def policy_loss(target, output):
            target, output = correct_policy(target, output)
            policy_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
                labels=tf.stop_gradient(target), logits=output)
            return tf.reduce_mean(input_tensor=policy_cross_entropy)

        self.policy_loss_fn = policy_loss

        def policy_accuracy(target, output):
            target, output = correct_policy(target, output)
            return tf.reduce_mean(
                tf.cast(
                    tf.equal(tf.argmax(input=target, axis=1),
                             tf.argmax(input=output, axis=1)),
                    tf.float32,
                ))

        self.policy_accuracy_fn = policy_accuracy

        q_ratio = self.cfg["training"].get("q_ratio", 0)
        assert 0 <= q_ratio <= 1

        # Linear conversion to scalar to compute MSE with, for comparison to old values
        wdl = tf.expand_dims(tf.constant([1.0, 0.0, -1.0]), 1)

        self.qMix = lambda z, q: q * q_ratio + z * (1 - q_ratio)
        # Loss on value head
        if self.wdl:

            def value_loss(target, output):
                output = tf.cast(output, tf.float32)
                value_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
                    labels=tf.stop_gradient(target), logits=output)
                return tf.reduce_mean(input_tensor=value_cross_entropy)

            self.value_loss_fn = value_loss

            def mse_loss(target, output):
                output = tf.cast(output, tf.float32)
                scalar_z_conv = tf.matmul(tf.nn.softmax(output), wdl)
                scalar_target = tf.matmul(target, wdl)
                return tf.reduce_mean(input_tensor=tf.math.squared_difference(
                    scalar_target, scalar_z_conv))

            self.mse_loss_fn = mse_loss
        else:

            def value_loss(target, output):
                return tf.constant(0)

            self.value_loss_fn = value_loss

            def mse_loss(target, output):
                output = tf.cast(output, tf.float32)
                scalar_target = tf.matmul(target, wdl)
                return tf.reduce_mean(input_tensor=tf.math.squared_difference(
                    scalar_target, output))

            self.mse_loss_fn = mse_loss

        pol_loss_w = self.cfg["training"]["policy_loss_weight"]
        val_loss_w = self.cfg["training"]["value_loss_weight"]
        self.lossMix = lambda policy, value: pol_loss_w * policy + val_loss_w * value

        def accuracy(target, output):
            output = tf.cast(output, tf.float32)
            return tf.reduce_mean(
                tf.cast(
                    tf.equal(tf.argmax(input=target, axis=1),
                             tf.argmax(input=output, axis=1)),
                    tf.float32,
                ))

        self.accuracy_fn = accuracy

        self.avg_policy_loss = []
        self.avg_value_loss = []
        self.avg_mse_loss = []
        self.avg_reg_term = []
        self.time_start = None
        self.last_steps = None
        # Set adaptive learning rate during training
        self.cfg["training"]["lr_boundaries"].sort()
        self.warmup_steps = self.cfg["training"].get("warmup_steps", 0)
        self.lr = self.cfg["training"]["lr_values"][0]
        self.test_writer = tf.summary.create_file_writer(
            os.path.join(os.getcwd(),
                         "leelalogs/{}-test".format(self.cfg["name"])))
        self.train_writer = tf.summary.create_file_writer(
            os.path.join(os.getcwd(),
                         "leelalogs/{}-train".format(self.cfg["name"])))
        self.validation_writer = tf.summary.create_file_writer(
            os.path.join(os.getcwd(),
                         "leelalogs/{}-validation".format(self.cfg['name'])))
        if self.swa_enabled:
            self.swa_writer = tf.summary.create_file_writer(
                os.path.join(os.getcwd(),
                             "leelalogs/{}-swa-test".format(self.cfg['name'])))
            self.swa_validation_writer = tf.summary.create_file_writer(
                os.path.join(
                    os.getcwd(),
                    "leelalogs/{}-swa-validation".format(self.cfg['name'])))
        self.checkpoint = tf.train.Checkpoint(
            optimizer=self.orig_optimizer,
            model=self.model,
            global_step=self.global_step,
            swa_count=self.swa_count,
        )
        self.checkpoint.listed = self.swa_weights
        self.manager = tf.train.CheckpointManager(
            self.checkpoint,
            directory=self.root_dir,
            max_to_keep=50,
            keep_checkpoint_every_n_hours=24,
            checkpoint_name=self.cfg['name'],
        )

    def replace_weights_v2(self, new_weights_orig):
        new_weights = [w for w in new_weights_orig]
        # self.model.weights ordering doesn't match the incoming ordering, so first shuffle the new weights to line up.
        # input order is (for convolutional policy):
        # policy conv
        # policy bn * 4
        # policy raw conv and bias
        # value conv
        # value bn * 4
        # value dense with bias
        # value dense with bias
        #
        # output order is (for convolutional policy):
        # value conv
        # policy conv
        # value bn * 4
        # policy bn * 4
        # policy raw conv and bias
        # value dense with bias
        # value dense with bias
        new_weights[-5] = new_weights_orig[-10]
        new_weights[-6] = new_weights_orig[-11]
        new_weights[-7] = new_weights_orig[-12]
        new_weights[-8] = new_weights_orig[-13]
        new_weights[-9] = new_weights_orig[-14]
        new_weights[-10] = new_weights_orig[-15]
        new_weights[-11] = new_weights_orig[-5]
        new_weights[-12] = new_weights_orig[-6]
        new_weights[-13] = new_weights_orig[-7]
        new_weights[-14] = new_weights_orig[-8]
        new_weights[-15] = new_weights_orig[-16]
        new_weights[-16] = new_weights_orig[-9]

        all_evals = []
        offset = 0
        last_was_gamma = False
        for e, weights in enumerate(self.model.weights):
            source_idx = e + offset
            if weights.shape.ndims == 4:
                # Rescale rule50 related weights as clients do not normalize the input.
                if e == 0:
                    num_inputs = 112
                    # 50 move rule is the 110th input, or 109 starting from 0.
                    rule50_input = 109
                    for i in range(len(new_weights[source_idx])):
                        if (i % (num_inputs * 9)) // 9 == rule50_input:
                            new_weights[source_idx][
                                i] = new_weights[source_idx][i] * 99

                # Convolution weights need a transpose
                #
                # TF (kYXInputOutput)
                # [filter_height, filter_width, in_channels, out_channels]
                #
                # Leela/cuDNN/Caffe (kOutputInputYX)
                # [output, input, filter_size, filter_size]
                s = weights.shape.as_list()
                shape = [s[i] for i in [3, 2, 0, 1]]
                new_weight = tf.constant(new_weights[source_idx], shape=shape)
                weights.assign(tf.transpose(a=new_weight, perm=[2, 3, 1, 0]))
            elif weights.shape.ndims == 2:
                # Fully connected layers are [in, out] in TF
                #
                # [out, in] in Leela
                #
                s = weights.shape.as_list()
                shape = [s[i] for i in [1, 0]]
                new_weight = tf.constant(new_weights[source_idx], shape=shape)
                weights.assign(tf.transpose(a=new_weight, perm=[1, 0]))
            else:
                # Renorm weights can't be populated, but the current new_weight entry is still needed elsewhere, so back up the source offset.
                if "renorm" in weights.name:
                    offset -= 1
                    continue
                # Betas without gammas need to skip the gamma in the input.
                if "beta:" in weights.name and not last_was_gamma:
                    source_idx += 1
                    offset += 1
                # Biases, batchnorm etc
                new_weight = tf.constant(new_weights[source_idx],
                                         shape=weights.shape)
                if "stddev:" in weights.name:
                    weights.assign(tf.math.sqrt(new_weight + 1e-5))
                else:
                    weights.assign(new_weight)
                # The variance is also needed to populate the stddev for renorm, so adjust the offset.
                if "variance:" in weights.name and self.renorm_enabled:
                    offset -= 1
            last_was_gamma = "gamma:" in weights.name
        # Replace the SWA weights as well, ensuring swa accumulation is reset.
        if self.swa_enabled:
            self.swa_count.assign(tf.constant(0.0))
            self.update_swa_v2()
        # This should result in identical file to the starting one
        # self.save_leelaz_weights_v2('restored.pb.gz')

    def restore_v2(self):
        if self.manager.latest_checkpoint is not None:
            print("Restoring from {0}".format(self.manager.latest_checkpoint))
            self.checkpoint.restore(self.manager.latest_checkpoint)

    def process_loop_v2(self, batch_size, test_batches, batch_splits=1):
        # Get the initial steps value in case this is a resume from a step count
        # which is not a multiple of total_steps.
        steps = self.global_step.read_value()
        total_steps = self.cfg["training"]["total_steps"]
        for _ in range(steps % total_steps, total_steps):
            self.process_v2(batch_size,
                            test_batches,
                            batch_splits=batch_splits)

    @tf.function()
    def read_weights(self):
        return [w.read_value() for w in self.model.weights]

    @tf.function()
    def process_inner_loop(self, x, y, z, q):
        with tf.GradientTape() as tape:
            policy, value = self.model(x, training=True)
            policy_loss = self.policy_loss_fn(y, policy)
            reg_term = sum(self.model.losses)
            if self.wdl:
                value_loss = self.value_loss_fn(self.qMix(z, q), value)
                total_loss = self.lossMix(policy_loss, value_loss) + reg_term
            else:
                mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
                total_loss = self.lossMix(policy_loss, mse_loss) + reg_term
            if self.loss_scale != 1:
                total_loss = self.optimizer.get_scaled_loss(total_loss)
        if self.wdl:
            mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
        else:
            value_loss = self.value_loss_fn(self.qMix(z, q), value)
        return (
            policy_loss,
            value_loss,
            mse_loss,
            reg_term,
            tape.gradient(total_loss, self.model.trainable_weights),
        )

    @tf.function()
    def strategy_process_inner_loop(self, x, y, z, q):
        (
            policy_loss,
            value_loss,
            mse_loss,
            reg_term,
            new_grads,
        ) = self.strategy.experimental_run_v2(self.process_inner_loop,
                                              args=(x, y, z, q))
        policy_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                           policy_loss,
                                           axis=None)
        value_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                          value_loss,
                                          axis=None)
        mse_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                        mse_loss,
                                        axis=None)
        reg_term = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                        reg_term,
                                        axis=None)
        return policy_loss, value_loss, mse_loss, reg_term, new_grads

    def apply_grads(self, grads, batch_splits):
        if self.loss_scale != 1:
            grads = self.optimizer.get_unscaled_gradients(grads)
        max_grad_norm = (self.cfg["training"].get("max_grad_norm", 10000.0) *
                         batch_splits)
        grads, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
        self.optimizer.apply_gradients(zip(grads,
                                           self.model.trainable_weights))
        return grads, grad_norm

    @tf.function()
    def strategy_apply_grads(self, grads, batch_splits):
        grads, grad_norm = self.strategy.experimental_run_v2(
            self.apply_grads, args=(grads, batch_splits))
        grads = [
            self.strategy.reduce(tf.distribute.ReduceOp.MEAN, x, axis=None)
            for x in grads
        ]
        grad_norm = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                         grad_norm,
                                         axis=None)
        return grads, grad_norm

    @tf.function()
    def merge_grads(self, grads, new_grads):
        return [tf.math.add(a, b) for (a, b) in zip(grads, new_grads)]

    @tf.function()
    def strategy_merge_grads(self, grads, new_grads):
        return self.strategy.experimental_run_v2(self.merge_grads,
                                                 args=(grads, new_grads))

    @tf.function()
    def add_lists(self, x, y):
        return [tf.math.add(a, b) for (a, b) in zip(x, y)]

    def process_v2(self, batch_size, test_batches, batch_splits=1):
        if not self.time_start:
            self.time_start = time.time()

        # Get the initial steps value before we do a training step.
        steps = self.global_step.read_value()
        if not self.last_steps:
            self.last_steps = steps

        if self.swa_enabled:
            # Split test_batches evenly between testing the regular weights and the SWA weights.
            test_batches //= 2

        # Run test before first step to see delta since end of last run.
        if steps % self.cfg["training"]["total_steps"] == 0:
            # Steps is given as one higher than current in order to avoid it
            # being equal to the value the end of a run is stored against.
            self.calculate_test_summaries_v2(test_batches, steps + 1)
            if self.swa_enabled:
                self.calculate_swa_summaries_v2(test_batches, steps + 1)

        # Make sure that ghost batch norm can be applied
        if batch_size % 64 != 0:
            # Adjust required batch size for batch splitting.
            required_factor = 64 * self.cfg["training"].get(
                "num_batch_splits", 1)
            raise ValueError(
                "batch_size must be a multiple of {}".format(required_factor))

        # Determine learning rate
        lr_values = self.cfg["training"]["lr_values"]
        lr_boundaries = self.cfg["training"]["lr_boundaries"]
        steps_total = steps % self.cfg["training"]["total_steps"]
        self.lr = lr_values[bisect.bisect_right(lr_boundaries, steps_total)]
        if self.warmup_steps > 0 and steps < self.warmup_steps:
            self.lr = self.lr * tf.cast(steps + 1,
                                        tf.float32) / self.warmup_steps

        # need to add 1 to steps because steps will be incremented after gradient update
        if (steps +
                1) % self.cfg["training"]["train_avg_report_steps"] == 0 or (
                    steps + 1) % self.cfg["training"]["total_steps"] == 0:
            before_weights = self.read_weights()

        # Run training for this batch
        grads = None
        for _ in range(batch_splits):
            x, y, z, q = next(self.train_iter)
            if self.strategy is not None:
                (
                    policy_loss,
                    value_loss,
                    mse_loss,
                    reg_term,
                    new_grads,
                ) = self.strategy_process_inner_loop(x, y, z, q)
            else:
                (
                    policy_loss,
                    value_loss,
                    mse_loss,
                    reg_term,
                    new_grads,
                ) = self.process_inner_loop(x, y, z, q)
            if not grads:
                grads = new_grads
            else:
                if self.strategy is not None:
                    grads = self.strategy_merge_grads(grads, new_grads)
                else:
                    grads = self.merge_grads(grads, new_grads)
            # Keep running averages
            # Google's paper scales MSE by 1/4 to a [0, 1] range, so do the same to
            # get comparable values.
            mse_loss /= 4.0
            self.avg_policy_loss.append(policy_loss)
            if self.wdl:
                self.avg_value_loss.append(value_loss)
            self.avg_mse_loss.append(mse_loss)
            self.avg_reg_term.append(reg_term)
        # Gradients of batch splits are summed, not averaged like usual, so need to scale lr accordingly to correct for this.
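        # e.g. batch_splits=4 on 2 replicas sums gradients over 8 micro-batches,
        # so the learning rate below is divided by 8 to keep the step size
        # unchanged.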
        effective_batch_splits = batch_splits
        if self.strategy is not None:
            effective_batch_splits = batch_splits * self.strategy.num_replicas_in_sync
        self.active_lr = self.lr / effective_batch_splits
        if self.strategy is not None:
            grads, grad_norm = self.strategy_apply_grads(grads, batch_splits)
        else:
            grads, grad_norm = self.apply_grads(grads, batch_splits)

        # Update steps.
        self.global_step.assign_add(1)
        steps = self.global_step.read_value()

        if (steps % self.cfg["training"]["train_avg_report_steps"] == 0
                or steps % self.cfg["training"]["total_steps"] == 0):
            pol_loss_w = self.cfg["training"]["policy_loss_weight"]
            val_loss_w = self.cfg["training"]["value_loss_weight"]
            time_end = time.time()
            speed = 0
            if self.time_start:
                elapsed = time_end - self.time_start
                steps_elapsed = steps - self.last_steps
                speed = batch_size * (tf.cast(steps_elapsed, tf.float32) /
                                      elapsed)
            avg_policy_loss = np.mean(self.avg_policy_loss or [0])
            avg_value_loss = np.mean(self.avg_value_loss or [0])
            avg_mse_loss = np.mean(self.avg_mse_loss or [0])
            avg_reg_term = np.mean(self.avg_reg_term or [0])
            print(
                "step {}, lr={:g} policy={:g} value={:g} mse={:g} reg={:g} total={:g} ({:g} pos/s)"
                .format(
                    steps,
                    self.lr,
                    avg_policy_loss,
                    avg_value_loss,
                    avg_mse_loss,
                    avg_reg_term,
                    pol_loss_w * avg_policy_loss +
                    val_loss_w * avg_value_loss + avg_reg_term,
                    speed,
                ))

            after_weights = self.read_weights()
            with self.train_writer.as_default():
                tf.summary.scalar("Policy Loss", avg_policy_loss, step=steps)
                tf.summary.scalar("Value Loss", avg_value_loss, step=steps)
                tf.summary.scalar("Reg term", avg_reg_term, step=steps)
                tf.summary.scalar("LR", self.lr, step=steps)
                tf.summary.scalar("Gradient norm",
                                  grad_norm / batch_splits,
                                  step=steps)
                tf.summary.scalar("MSE Loss", avg_mse_loss, step=steps)
                self.compute_update_ratio_v2(before_weights, after_weights,
                                             steps)
            self.train_writer.flush()
            self.time_start = time_end
            self.last_steps = steps
            (
                self.avg_policy_loss,
                self.avg_value_loss,
                self.avg_mse_loss,
                self.avg_reg_term,
            ) = ([], [], [], [])

        if self.swa_enabled and steps % self.cfg["training"]["swa_steps"] == 0:
            self.update_swa_v2()

        # Calculate test values every 'test_steps', but also ensure there is
        # one at the final step so the delta to the first step can be calculated.
        if (steps % self.cfg["training"]["test_steps"] == 0
                or steps % self.cfg["training"]["total_steps"] == 0):
            self.calculate_test_summaries_v2(test_batches, steps)
            if self.swa_enabled:
                self.calculate_swa_summaries_v2(test_batches, steps)

        if self.validation_dataset is not None and (
                steps % self.cfg['training']['validation_steps'] == 0
                or steps % self.cfg['training']['total_steps'] == 0):
            if self.swa_enabled:
                self.calculate_swa_validations_v2(steps)
            else:
                self.calculate_test_validations_v2(steps)

        # Save session and weights at end, and also optionally every 'checkpoint_steps'.
        if steps % self.cfg["training"]["total_steps"] == 0 or (
                "checkpoint_steps" in self.cfg["training"]
                and steps % self.cfg["training"]["checkpoint_steps"] == 0):
            evaled_steps = steps.numpy()
            self.manager.save(checkpoint_number=evaled_steps)
            print("Model saved in file: {}".format(
                self.manager.latest_checkpoint))
            path = os.path.join(self.root_dir, self.cfg['name'])
            leela_path = path + "-" + str(evaled_steps)
            swa_path = path + "-swa-" + str(evaled_steps)
            self.net.pb.training_params.training_steps = evaled_steps
            self.save_leelaz_weights_v2(leela_path)
            if self.swa_enabled:
                self.save_swa_weights_v2(swa_path)

    @tf.function()
    def switch_to_swa(self):
        backup = self.read_weights()
        for (swa, w) in zip(self.swa_weights, self.model.weights):
            w.assign(swa.read_value())
        return backup

    @tf.function()
    def restore_weights(self, backup):
        for (old, w) in zip(backup, self.model.weights):
            w.assign(old)

    def calculate_swa_summaries_v2(self, test_batches, steps):
        backup = self.switch_to_swa()
        true_test_writer, self.test_writer = self.test_writer, self.swa_writer
        print("swa", end=" ")
        self.calculate_test_summaries_v2(test_batches, steps)
        self.test_writer = true_test_writer
        self.restore_weights(backup)

    @tf.function()
    def calculate_test_summaries_inner_loop(self, x, y, z, q):
        policy, value = self.model(x, training=False)
        policy_loss = self.policy_loss_fn(y, policy)
        policy_accuracy = self.policy_accuracy_fn(y, policy)
        if self.wdl:
            value_loss = self.value_loss_fn(self.qMix(z, q), value)
            mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
            value_accuracy = self.accuracy_fn(self.qMix(z, q), value)
        else:
            value_loss = self.value_loss_fn(self.qMix(z, q), value)
            mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
            value_accuracy = tf.constant(0.0)
        return policy_loss, value_loss, mse_loss, policy_accuracy, value_accuracy

    @tf.function()
    def strategy_calculate_test_summaries_inner_loop(self, x, y, z, q):
        (
            policy_loss,
            value_loss,
            mse_loss,
            policy_accuracy,
            value_accuracy,
        ) = self.strategy.experimental_run_v2(
            self.calculate_test_summaries_inner_loop, args=(x, y, z, q))
        policy_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                           policy_loss,
                                           axis=None)
        value_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                          value_loss,
                                          axis=None)
        mse_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                        mse_loss,
                                        axis=None)
        policy_accuracy = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                               policy_accuracy,
                                               axis=None)
        value_accuracy = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                              value_accuracy,
                                              axis=None)
        return policy_loss, value_loss, mse_loss, policy_accuracy, value_accuracy

    def calculate_test_summaries_v2(self, test_batches, steps):
        sum_policy_accuracy = 0
        sum_value_accuracy = 0
        sum_mse = 0
        sum_policy = 0
        sum_value = 0
        for _ in range(0, test_batches):
            x, y, z, q = next(self.test_iter)
            if self.strategy is not None:
                (
                    policy_loss,
                    value_loss,
                    mse_loss,
                    policy_accuracy,
                    value_accuracy,
                ) = self.strategy_calculate_test_summaries_inner_loop(
                    x, y, z, q)
            else:
                (
                    policy_loss,
                    value_loss,
                    mse_loss,
                    policy_accuracy,
                    value_accuracy,
                ) = self.calculate_test_summaries_inner_loop(x, y, z, q)
            sum_policy_accuracy += policy_accuracy
            sum_mse += mse_loss
            sum_policy += policy_loss
            if self.wdl:
                sum_value_accuracy += value_accuracy
                sum_value += value_loss
        sum_policy_accuracy /= test_batches
        sum_policy_accuracy *= 100
        sum_policy /= test_batches
        sum_value /= test_batches
        if self.wdl:
            sum_value_accuracy /= test_batches
            sum_value_accuracy *= 100
        # Additionally rescale to [0, 1] so divide by 4
        sum_mse /= 4.0 * test_batches
        self.net.pb.training_params.learning_rate = self.lr
        self.net.pb.training_params.mse_loss = sum_mse
        self.net.pb.training_params.policy_loss = sum_policy
        # TODO store value and value accuracy in pb
        self.net.pb.training_params.accuracy = sum_policy_accuracy
        with self.test_writer.as_default():
            tf.summary.scalar("Policy Loss", sum_policy, step=steps)
            tf.summary.scalar("Value Loss", sum_value, step=steps)
            tf.summary.scalar("MSE Loss", sum_mse, step=steps)
            tf.summary.scalar("Policy Accuracy",
                              sum_policy_accuracy,
                              step=steps)
            if self.wdl:
                tf.summary.scalar("Value Accuracy",
                                  sum_value_accuracy,
                                  step=steps)
            for w in self.model.weights:
                tf.summary.histogram(w.name, w, step=steps)
        self.test_writer.flush()

        print(
            "step {}, policy={:g} value={:g} policy accuracy={:g}% value accuracy={:g}% mse={:g}"
            .format(
                steps,
                sum_policy,
                sum_value,
                sum_policy_accuracy,
                sum_value_accuracy,
                sum_mse,
            ))

    def calculate_swa_validations_v2(self, steps):
        backup = self.read_weights()
        for (swa, w) in zip(self.swa_weights, self.model.weights):
            w.assign(swa.read_value())
        true_validation_writer, self.validation_writer = self.validation_writer, self.swa_validation_writer
        print('swa', end=' ')
        self.calculate_test_validations_v2(steps)
        self.validation_writer = true_validation_writer
        for (old, w) in zip(backup, self.model.weights):
            w.assign(old)

    def calculate_test_validations_v2(self, steps):
        sum_policy_accuracy = 0
        sum_value_accuracy = 0
        sum_mse = 0
        sum_policy = 0
        sum_value = 0
        counter = 0
        for (x, y, z, q) in self.validation_dataset:
            policy_loss, value_loss, mse_loss, policy_accuracy, value_accuracy = self.calculate_test_summaries_inner_loop(
                x, y, z, q)
            sum_policy_accuracy += policy_accuracy
            sum_mse += mse_loss
            sum_policy += policy_loss
            counter += 1
            if self.wdl:
                sum_value_accuracy += value_accuracy
                sum_value += value_loss
        sum_policy_accuracy /= counter
        sum_policy_accuracy *= 100
        sum_policy /= counter
        sum_value /= counter
        if self.wdl:
            sum_value_accuracy /= counter
            sum_value_accuracy *= 100
        # Additionally rescale to [0, 1] so divide by 4
        sum_mse /= (4.0 * counter)
        with self.validation_writer.as_default():
            tf.summary.scalar("Policy Loss", sum_policy, step=steps)
            tf.summary.scalar("Value Loss", sum_value, step=steps)
            tf.summary.scalar("MSE Loss", sum_mse, step=steps)
            tf.summary.scalar("Policy Accuracy",
                              sum_policy_accuracy,
                              step=steps)
            if self.wdl:
                tf.summary.scalar("Value Accuracy",
                                  sum_value_accuracy,
                                  step=steps)
        self.validation_writer.flush()

        print("step {}, validation: policy={:g} value={:g} policy accuracy={:g}% value accuracy={:g}% mse={:g}".\
            format(steps, sum_policy, sum_value, sum_policy_accuracy, sum_value_accuracy, sum_mse))

    @tf.function()
    def compute_update_ratio_v2(self, before_weights, after_weights, steps):
        """Compute the ratio of gradient norm to weight norm.

        Adapted from https://github.com/tensorflow/minigo/blob/c923cd5b11f7d417c9541ad61414bf175a84dc31/dual_net.py#L567
        """
        deltas = [
            after - before
            for after, before in zip(after_weights, before_weights)
        ]
        delta_norms = [tf.math.reduce_euclidean_norm(d) for d in deltas]
        weight_norms = [
            tf.math.reduce_euclidean_norm(w) for w in before_weights
        ]
        ratios = [(tensor.name, tf.cond(w != 0.0, lambda: d / w, lambda: -1.0))
                  for d, w, tensor in zip(delta_norms, weight_norms,
                                          self.model.weights)
                  if not "moving" in tensor.name]
        for name, ratio in ratios:
            tf.summary.scalar("update_ratios/" + name, ratio, step=steps)
        # Filtering is hard, so just push infinities/NaNs to an unreasonably large value.
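        # 2.30258509299 is ln(10), so log(r) / ln(10) converts the natural
        # log to log10 for the histogram below.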
        ratios = [
            tf.cond(r > 0, lambda: tf.math.log(r) / 2.30258509299,
                    lambda: 200.0) for (_, r) in ratios
        ]
        tf.summary.histogram("update_ratios_log10",
                             tf.stack(ratios),
                             buckets=1000,
                             step=steps)

    @tf.function()
    def update_swa_v2(self):
        num = self.swa_count.read_value()
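        # Incremental running mean: swa <- swa * n/(n+1) + w/(n+1). Capping
        # the count at swa_max_n turns this into an exponential moving average
        # once the cap is reached.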
        for (w, swa) in zip(self.model.weights, self.swa_weights):
            swa.assign(swa.read_value() * (num / (num + 1.)) + w.read_value() *
                       (1. / (num + 1.)))
        self.swa_count.assign(tf.math.minimum(num + 1., self.swa_max_n))

    def save_swa_weights_v2(self, filename):
        backup = self.switch_to_swa()
        self.save_leelaz_weights_v2(filename)
        self.restore_weights(backup)

    def save_leelaz_weights_v2(self, filename):
        all_tensors = []
        all_weights = []
        last_was_gamma = False
        for weights in self.model.weights:
            work_weights = None
            if weights.shape.ndims == 4:
                # Convolution weights need a transpose
                #
                # TF (kYXInputOutput)
                # [filter_height, filter_width, in_channels, out_channels]
                #
                # Leela/cuDNN/Caffe (kOutputInputYX)
                # [output, input, filter_size, filter_size]
                work_weights = tf.transpose(a=weights, perm=[3, 2, 0, 1])
            elif weights.shape.ndims == 2:
                # Fully connected layers are [in, out] in TF
                #
                # [out, in] in Leela
                #
                work_weights = tf.transpose(a=weights, perm=[1, 0])
            else:
                # batch renorm has extra weights, but we don't know what to do with them.
                if "renorm" in weights.name:
                    continue
                # renorm has variance, but it is not the primary source of truth
                if "variance:" in weights.name and self.renorm_enabled:
                    continue
                # Renorm has moving stddev not variance, undo the transform to make it compatible.
                if "stddev:" in weights.name:
                    all_tensors.append(tf.math.square(weights) - 1e-5)
                    continue
                # Biases, batchnorm etc
                # pb expects every batch norm to have gammas, but not all of our
                # batch norms have gammas, so manually add pretend gammas.
                if "beta:" in weights.name and not last_was_gamma:
                    all_tensors.append(tf.ones_like(weights))
                work_weights = weights.read_value()
            all_tensors.append(work_weights)
            last_was_gamma = "gamma:" in weights.name

        # HACK: model weights are ordered by a kind of breadth-first traversal,
        # but pb expects a specific ordering that the traversal stops matching
        # once we reach the heads, so apply a manual permutation.
        # This is fragile and at minimum should have some checks to ensure it
        # isn't breaking things.
        # TODO: also support the classical policy head, as it has a different
        # set of layers and hence changes the permutation.
        permuted_tensors = [w for w in all_tensors]
        permuted_tensors[-5] = all_tensors[-11]
        permuted_tensors[-6] = all_tensors[-12]
        permuted_tensors[-7] = all_tensors[-13]
        permuted_tensors[-8] = all_tensors[-14]
        permuted_tensors[-9] = all_tensors[-16]
        permuted_tensors[-10] = all_tensors[-5]
        permuted_tensors[-11] = all_tensors[-6]
        permuted_tensors[-12] = all_tensors[-7]
        permuted_tensors[-13] = all_tensors[-8]
        permuted_tensors[-14] = all_tensors[-9]
        permuted_tensors[-15] = all_tensors[-10]
        permuted_tensors[-16] = all_tensors[-15]
        all_tensors = permuted_tensors

        for e, nparray in enumerate(all_tensors):
            # Rescale rule50 related weights as clients do not normalize the input.
            if e == 0:
                num_inputs = 112
                # 50 move rule is the 110th input, or 109 starting from 0.
                rule50_input = 109
                wt_flt = []
                for i, weight in enumerate(np.ravel(nparray)):
                    if (i % (num_inputs * 9)) // 9 == rule50_input:
                        wt_flt.append(weight / 99)
                    else:
                        wt_flt.append(weight)
            else:
                wt_flt = [wt for wt in np.ravel(nparray)]
            all_weights.append(wt_flt)

        self.net.fill_net(all_weights)
        self.net.save_proto(filename)

    def batch_norm_v2(self, input, scale=False):
        if self.renorm_enabled:
            clipping = {
                "rmin": 1.0 / self.renorm_max_r,
                "rmax": self.renorm_max_r,
                "dmax": self.renorm_max_d,
            }
            return tf.keras.layers.BatchNormalization(
                epsilon=1e-5,
                axis=1,
                fused=False,
                center=True,
                scale=scale,
                renorm=True,
                renorm_clipping=clipping,
                renorm_momentum=self.renorm_momentum,
            )(input)
        else:
            return tf.keras.layers.BatchNormalization(
                epsilon=1e-5,
                axis=1,
                fused=False,
                center=True,
                scale=scale,
                virtual_batch_size=64,
            )(input)

    def squeeze_excitation_v2(self, inputs, channels):
        assert channels % self.SE_ratio == 0
        pooled = tf.keras.layers.GlobalAveragePooling2D(
            data_format="channels_first")(inputs)
        squeezed = tf.keras.layers.Activation("relu")(tf.keras.layers.Dense(
            channels // self.SE_ratio,
            kernel_initializer="glorot_normal",
            kernel_regularizer=self.l2reg,
        )(pooled))
        excited = tf.keras.layers.Dense(
            2 * channels,
            kernel_initializer="glorot_normal",
            kernel_regularizer=self.l2reg,
        )(squeezed)
        return ApplySqueezeExcitation()([inputs, excited])

    def conv_block_v2(self,
                      inputs,
                      filter_size,
                      output_channels,
                      bn_scale=False):
        conv = tf.keras.layers.Conv2D(
            output_channels,
            filter_size,
            use_bias=False,
            padding="same",
            kernel_initializer="glorot_normal",
            kernel_regularizer=self.l2reg,
            data_format="channels_first",
        )(inputs)
        return tf.keras.layers.Activation("relu")(self.batch_norm_v2(
            conv, scale=bn_scale))

    def residual_block_v2(self, inputs, channels):
        conv1 = tf.keras.layers.Conv2D(
            channels,
            3,
            use_bias=False,
            padding="same",
            kernel_initializer="glorot_normal",
            kernel_regularizer=self.l2reg,
            data_format="channels_first",
        )(inputs)
        out1 = tf.keras.layers.Activation("relu")(self.batch_norm_v2(
            conv1, scale=False))
        conv2 = tf.keras.layers.Conv2D(
            channels,
            3,
            use_bias=False,
            padding="same",
            kernel_initializer="glorot_normal",
            kernel_regularizer=self.l2reg,
            data_format="channels_first",
        )(out1)
        out2 = self.squeeze_excitation_v2(
            self.batch_norm_v2(conv2, scale=True), channels)
        return tf.keras.layers.Activation("relu")(tf.keras.layers.add(
            [inputs, out2]))

    def construct_net_v2(self, inputs):
        flow = self.conv_block_v2(inputs,
                                  filter_size=3,
                                  output_channels=self.RESIDUAL_FILTERS,
                                  bn_scale=True)
        for _ in range(0, self.RESIDUAL_BLOCKS):
            flow = self.residual_block_v2(flow, self.RESIDUAL_FILTERS)
        # Policy head
        if self.POLICY_HEAD == pb.NetworkFormat.POLICY_CONVOLUTION:
            conv_pol = self.conv_block_v2(
                flow, filter_size=3, output_channels=self.RESIDUAL_FILTERS)
            conv_pol2 = tf.keras.layers.Conv2D(
                80,
                3,
                use_bias=True,
                padding="same",
                kernel_initializer="glorot_normal",
                kernel_regularizer=self.l2reg,
                bias_regularizer=self.l2reg,
                data_format="channels_first",
            )(conv_pol)
            h_fc1 = ApplyPolicyMap()(conv_pol2)
        elif self.POLICY_HEAD == pb.NetworkFormat.POLICY_CLASSICAL:
            conv_pol = self.conv_block_v2(flow,
                                          filter_size=1,
                                          output_channels=self.policy_channels)
            h_conv_pol_flat = tf.keras.layers.Flatten()(conv_pol)
            h_fc1 = tf.keras.layers.Dense(
                1858,
                kernel_initializer="glorot_normal",
                kernel_regularizer=self.l2reg,
                bias_regularizer=self.l2reg,
            )(h_conv_pol_flat)
        else:
            raise ValueError("Unknown policy head type {}".format(
                self.POLICY_HEAD))

        # Value head
        conv_val = self.conv_block_v2(flow, filter_size=1, output_channels=32)
        h_conv_val_flat = tf.keras.layers.Flatten()(conv_val)
        h_fc2 = tf.keras.layers.Dense(
            128,
            kernel_initializer="glorot_normal",
            kernel_regularizer=self.l2reg,
            activation="relu",
        )(h_conv_val_flat)
        if self.wdl:
            h_fc3 = tf.keras.layers.Dense(
                3,
                kernel_initializer="glorot_normal",
                kernel_regularizer=self.l2reg,
                bias_regularizer=self.l2reg,
            )(h_fc2)
        else:
            h_fc3 = tf.keras.layers.Dense(
                1,
                kernel_initializer="glorot_normal",
                kernel_regularizer=self.l2reg,
                activation="tanh",
            )(h_fc2)
        return h_fc1, h_fc3
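
A minimal driver sketch for the class above. The config keys mirror what the constructor and init_net_v2 read; all values here are hypothetical, and train_dataset / test_dataset are assumed to already exist and yield (x, y, z, q) batches:

cfg = {
    "name": "example-net",
    "gpu": 0,
    "model": {"filters": 64, "residual_blocks": 6, "se_ratio": 8},
    "training": {
        "path": "networks",
        "total_steps": 1000,
        "lr_values": [0.01],
        "lr_boundaries": [],
        "policy_loss_weight": 1.0,
        "value_loss_weight": 1.0,
        "train_avg_report_steps": 100,
        "test_steps": 500,
    },
}

tfp = TFProcess(cfg)
tfp.init_v2(train_dataset, test_dataset)
tfp.restore_v2()  # resume from the latest checkpoint, if any
tfp.process_loop_v2(batch_size=1024, test_batches=10)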
Example #4
model = model.to(device)
model = torch.nn.DataParallel(model)
fc = Full_layer(num_feature=num_feature, num_classes=len(classes)).to(device)
fc = torch.nn.DataParallel(fc)

# train from start
best_top1 = 0
start_epoch = 0

# criterion and optimizer
isda_criterion = ISDALoss(num_feature, num_classes)
ce_criterion = CrossEntropyLoss(smooth_eps=0.1).to(device)

if args.adamod:
    print("\n Using AdaMod optimizer")
    optimizer = AdaMod([{'params': model.parameters()},
                        {'params': fc.parameters()}], lr=args.learning_rate, weight_decay=args.weight_decay)

if args.deepmemory:
    print("\n Using DeepMemory optimizer")
    optimizer = Lookahead(DeepMemory([{'params': model.parameters()},
                                      {'params': fc.parameters()}], lr=args.learning_rate, weight_decay=args.weight_decay))
if args.adalook:
    print("\n Using AdaMod+LookAhead optimizer")
    optimizer = Lookahead(AdaMod([{'params': model.parameters()},
                                  {'params': fc.parameters()}], lr=args.learning_rate, weight_decay=args.weight_decay))

if args.ranger:
    print("\n Using RAdam+LookAhead optimizer")
    optimizer = Lookahead(RAdam([{'params': model.parameters()},
                                 {'params': fc.parameters()}], lr=args.learning_rate, weight_decay=args.weight_decay))
model = EfficientNet.from_pretrained('efficientnet-b4', num_classes=10)
model = model.to(device)
model = torch.nn.DataParallel(model)

# train from start
best_top1 = 0
start_epoch = 0

# criterion and optimizer
params = [p for p in model.parameters()]
criterion = CrossEntropyLoss(smooth_eps=0.1).to(device)

if args.adamod:
    print("\n Using AdaMod optimizer")
    optimizer = AdaMod(params,
                       lr=args.learning_rate,
                       weight_decay=args.weight_decay)

if args.deepmemory:
    print("\n Using DeepMemory optimizer")
    optimizer = Lookahead(
        DeepMemory(params,
                   lr=args.learning_rate,
                   weight_decay=args.weight_decay))

if args.adalook:
    print("\n Using AdaMod+LookAhead optimizer")
    optimizer = Lookahead(
        AdaMod(params, lr=args.learning_rate, weight_decay=args.weight_decay))

if args.ranger:
    print("\n Using RAdam+LookAhead optimizer")
    optimizer = Lookahead(
        RAdam(params, lr=args.learning_rate, weight_decay=args.weight_decay))
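
Whichever flag is set, the chosen optimizer is then driven through the standard PyTorch interface. A minimal training-step sketch with hypothetical inputs/targets tensors, assuming the Lookahead wrapper forwards the usual Optimizer API:

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()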