""" print(isinstance(net.loss_var.owner_opr, WeightDecay)) print(net.loss_var.owner_opr._params) print(type(net.loss_var.owner_opr._param_weights)) exit() """ train_func = env.make_func_from_loss_var(net.loss_var, "train", train_state=True) valid_func = env.make_func_from_loss_var(net.loss_var, "val", train_state=False) lr = 0.1 optimizer = Momentum(lr, 0.9) #optimizer.learning_rate = 0.01 optimizer(train_func) train_func.comp_graph.share_device_memory_with(valid_func.comp_graph) dic = { "loss": net.loss_var, "pre_loss": preloss, "outputs": net.outputs[0] } train_func.compile(dic) valid_func.compile(dic) env.register_checkpoint_component("network", net) env.register_checkpoint_component("opt_state",
net.loss_var = WeightDecay(net.loss_var, { "*conv*:W": 1e-4, "*fc*:W": 1e-4, "*bnaff*:k": 1e-4, "*offset*": 1e-4 }) train_func = env.make_func_from_loss_var(net.loss_var, "train", train_state=True) valid_func = env.make_func_from_loss_var(net.loss_var, "val", train_state=False) lr = 0.1 optimizer = Momentum(lr, 0.9) #optimizer.learning_rate = 0.01 optimizer(train_func) train_func.comp_graph.share_device_memory_with(valid_func.comp_graph) dic = { "loss": net.loss_var, "pre_loss": preloss, "outputs": net.outputs[0] } train_func.compile(dic) valid_func.compile(dic) env.register_checkpoint_component("network", net) env.register_checkpoint_component("opt_state",
import shutil

# Options are registered on this parser and parsed inside the TrainingEnv
# context below (TrainingEnv receives it as ``custom_parser``).
parser = argparse.ArgumentParser()

# Start from an empty TensorBoard log directory. Best-effort delete without
# spawning a shell; a missing directory on the first run is not an error.
shutil.rmtree("tbdata/", ignore_errors=True)
tb = TB("tbdata/")

with TrainingEnv(name="lyy.{}.test".format(net_name), part_count=2,
                 custom_parser=parser) as env:
    args = parser.parse_args()
    # NOTE(review): args.devices is presumably a comma-separated device list
    # supplied via TrainingEnv's parser -- confirm the option format.
    num_GPU = len(args.devices.split(','))
    # minibatch_size is multiplied by the number of listed devices.
    minibatch_size *= num_GPU
    net, SS_list = make_network(minibatch_size=minibatch_size)

    # Keep a handle to the loss before weight decay is added; it is compiled
    # into train_func below under the "pre_loss" key.
    preloss = net.loss_var
    # NOTE(review): "*conv*" and "*fc*" carry no ":W" suffix (unlike
    # "*bnaff*:k"), so they presumably match every parameter of those layers,
    # biases included -- confirm this is intended.
    net.loss_var = WeightDecay(net.loss_var, {
        "*conv*": 1e-4,
        "*fc*": 1e-4,
        "*bnaff*:k": 1e-4,
        "*offset*": 1e-4,
    })

    train_func = env.make_func_from_loss_var(net.loss_var, "train", train_state=True)
    # The learning rate is scaled by the device count, like minibatch_size.
    lr = 0.1 * num_GPU
    optimizer = Momentum(lr, 0.9)
    optimizer(train_func)

    dic = {
        "loss": net.loss_var,
        "pre_loss": preloss,
        "outputs": net.outputs[0],
    }
    train_func.compile(dic)
    # Validation compiles only the network outputs as a standalone Function;
    # its graph is not set up to share device memory with train_func.
    valid_func = Function().compile(net.outputs[0])

    env.register_checkpoint_component("network", net)
    env.register_checkpoint_component("opt_state", train_func.optimizer_state)