Beispiel #1
0
    def learn(self):
        """Run one gradient step of the quantile network on a replay batch.

        Increments ``self.learning_steps``, samples ``self.batch_size``
        transitions, one-hot encodes the current and next state indices over
        ``self.env.nrow * self.env.ncol`` states, computes the quantile loss
        with ``self.calculate_loss`` and applies it to ``self.online_net``
        through ``update_params``. When ``4 * self.steps`` is a multiple of
        ``self.log_interval`` the loss and mean Q value are written to
        ``self.writer`` at x-position ``4 * self.steps``.
        """
        self.learning_steps += 1

        batch = self.memory.sample(self.batch_size)
        states, actions, rewards, next_states, dones = batch

        # One shared lookup table: row i is the one-hot code of state i.
        # Advanced indexing copies rows, so reusing it for both is safe.
        num_states = self.env.nrow * self.env.ncol
        one_hot = torch.eye(num_states, dtype=torch.float32)
        states = one_hot[states].to(self.device)
        next_states = one_hot[next_states].to(self.device)

        embeddings = self.online_net.calculate_state_embeddings(states)
        quantile_loss, mean_q = self.calculate_loss(
            embeddings, actions, rewards, next_states, dones)

        update_params(self.optim, quantile_loss,
                      networks=[self.online_net], retain_graph=False,
                      grad_cliping=self.grad_cliping)

        log_step = 4 * self.steps
        if log_step % self.log_interval == 0:
            self.writer.add_scalar('loss/quantile_loss',
                                   quantile_loss.detach().item(), log_step)
            self.writer.add_scalar('stats/mean_Q', mean_q, log_step)
    def learn_latent(self):
        """Perform one optimisation step of the latent model.

        Draws ``self.latent_batch_size`` sequences with
        ``self.memory.sample_latent``, computes the loss with
        ``self.calc_latent_loss`` and hands it to ``update_params`` together
        with ``self.latent_optim``, ``self.latent`` and ``self.grad_clip``.
        The loss is logged whenever ``self.learning_steps`` is a multiple of
        ``self.learning_log_interval``.
        """
        sampled = self.memory.sample_latent(self.latent_batch_size)
        images_seq, actions_seq, rewards_seq, dones_seq = sampled

        loss = self.calc_latent_loss(images_seq, actions_seq, rewards_seq,
                                     dones_seq)
        update_params(self.latent_optim, self.latent, loss, self.grad_clip)

        step = self.learning_steps
        if step % self.learning_log_interval == 0:
            self.writer.add_scalar('loss/latent', loss.detach().item(), step)
Beispiel #3
0
    def learn(self):
        """Perform a single soft actor-critic update step.

        Increments ``self.learning_steps`` and, every
        ``self.target_update_interval`` steps, calls
        ``soft_update(self.critic_target, self.critic, self.tau)``. Samples
        ``self.batch_size`` transitions; with prioritized replay
        (``self.per``) the memory also returns indices and priority weights,
        otherwise a constant weight of ``1.`` is used. Both critics and the
        policy are then updated, the temperature too when
        ``self.entropy_tuning`` is set, and under PER the replay priorities
        are refreshed from the critic errors.
        """
        self.learning_steps += 1
        if self.learning_steps % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        if not self.per:
            batch = self.memory.sample(self.batch_size)
            # Uniform replay: every sample carries the same weight.
            weights = 1.
        else:
            # Prioritized replay also yields indices and priority weights.
            batch, indices, weights = self.memory.sample(self.batch_size)

        q1_loss, q2_loss, errors, _, _ = self.calc_critic_loss(batch, weights)
        policy_loss, entropies = self.calc_policy_loss(batch, weights)

        for optim, network, loss in (
                (self.q1_optim, self.critic.Q1, q1_loss),
                (self.q2_optim, self.critic.Q2, q2_loss),
                (self.policy_optim, self.policy, policy_loss)):
            update_params(optim, network, loss, self.grad_clip)

        if self.entropy_tuning:
            entropy_loss = self.calc_entropy_loss(entropies, weights)
            update_params(self.alpha_optim, None, entropy_loss)
            # Re-derive alpha from the freshly updated log-temperature.
            self.alpha = self.log_alpha.exp()

        if self.per:
            # Feed the new critic errors back as replay priorities.
            self.memory.update_priority(indices, errors.cpu().numpy())
Beispiel #4
0
    def learn_sac(self):
        """Update the SAC critics, policy and (optionally) temperature.

        A batch of ``self.batch_size`` sequences is drawn with
        ``self.memory.sample_sac``. Features and posterior latent samples are
        computed under ``torch.no_grad()`` so this update sends no gradient
        into the latent model. The critics use the last two latent steps and
        the last action. The policy uses the feature/action history built by
        ``create_feature_actions``. Scalars are written to ``self.writer``
        whenever ``self.learning_steps`` is a multiple of
        ``self.learning_log_interval``.
        """
        def as_float(value):
            # Logged values may be tensors or plain Python numbers:
            # ``entropy_loss`` is the float ``0.`` when entropy tuning is off,
            # and calling ``.detach()`` on it raised AttributeError.
            if torch.is_tensor(value):
                return value.detach().item()
            return float(value)

        images_seq, actions_seq, rewards =\
            self.memory.sample_sac(self.batch_size)

        # NOTE: Don't update the encoder part of the policy here.
        with torch.no_grad():
            # f(1:t+1)
            features_seq = self.latent.encoder(images_seq)
            latent_samples, _ = self.latent.sample_posterior(
                features_seq, actions_seq)

        # z(t), z(t+1)
        latents_seq = torch.cat(latent_samples, dim=-1)
        latents = latents_seq[:, -2]
        next_latents = latents_seq[:, -1]
        # a(t)
        actions = actions_seq[:, -1]
        # fa(t)=(x(1:t), a(1:t-1)), fa(t+1)=(x(2:t+1), a(2:t))
        feature_actions, next_feature_actions =\
            create_feature_actions(features_seq, actions_seq)

        q1_loss, q2_loss = self.calc_critic_loss(latents, next_latents,
                                                 actions, next_feature_actions,
                                                 rewards)
        policy_loss, entropies = self.calc_policy_loss(latents,
                                                       feature_actions)

        update_params(self.q1_optim, self.critic.Q1, q1_loss, self.grad_clip)
        update_params(self.q2_optim, self.critic.Q2, q2_loss, self.grad_clip)
        update_params(self.policy_optim, self.policy, policy_loss,
                      self.grad_clip)

        if self.entropy_tuning:
            entropy_loss = self.calc_entropy_loss(entropies)
            update_params(self.alpha_optim, None, entropy_loss)
            self.alpha = self.log_alpha.exp()
        else:
            entropy_loss = 0.

        if self.learning_steps % self.learning_log_interval == 0:
            step = self.learning_steps
            self.writer.add_scalar('loss/Q1', as_float(q1_loss), step)
            self.writer.add_scalar('loss/Q2', as_float(q2_loss), step)
            self.writer.add_scalar('loss/policy', as_float(policy_loss), step)
            self.writer.add_scalar('loss/alpha', as_float(entropy_loss), step)
            self.writer.add_scalar('stats/alpha', as_float(self.alpha), step)
            self.writer.add_scalar('stats/entropy',
                                   as_float(entropies.detach().mean()), step)
    def learn_latent(self):
        """Perform one optimisation step of the skill-conditioned latent model.

        Samples ``self.latent_batch_size`` sequences of images, actions,
        skills and done flags, computes ``self.calc_latent_loss`` on them and
        passes the loss to ``update_params`` with ``self.latent_optim``,
        ``self.latent`` and ``self.grad_clip``. When
        ``self._is_log(self.learning_log_interval * 5)`` is true, the latent
        network parameters are written to ``self.writer``.
        """
        batch = self.memory.sample_latent(self.latent_batch_size)
        images_seq, actions_seq, skill_seq, dones_seq = batch

        latent_loss = self.calc_latent_loss(
            images_seq, actions_seq, skill_seq, dones_seq)
        update_params(self.latent_optim, self.latent, latent_loss,
                      self.grad_clip)

        # Parameter dumps use a 5x larger interval than the base log interval.
        if not self._is_log(self.learning_log_interval * 5):
            return
        self.latent.write_net_params(self.writer, self.learning_steps)
Beispiel #6
0
    def bandwidth(self, subid, params=None):
        """/v1/server/bandwidth
        GET - account
        Report the bandwidth consumed by one virtual machine.

        Link: https://www.vultr.com/api/#server_bandwidth

        :param subid: subscription ID of the machine (sent as ``SUBID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SUBID``.
        :return: the result of ``self.request``.
        """
        required = {'SUBID': subid}
        payload = update_params(params, required)
        return self.request('/v1/server/bandwidth', payload, 'GET')
Beispiel #7
0
    def records(self, domain, params=None):
        """/v1/dns/records
        GET - account
        List every DNS record that belongs to one domain.

        Link: https://www.vultr.com/api/#dns_records

        :param domain: domain whose records are listed (sent as ``domain``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``domain``.
        :return: the result of ``self.request``.
        """
        query = update_params(params, {'domain': domain})
        return self.request('/v1/dns/records', query, 'GET')
    def update(self, scriptid, params=None):
        """/v1/startupscript/update
        POST - account
        Modify a startup script that already exists.

        Link: https://www.vultr.com/api/#startupscript_update

        :param scriptid: ID of the script to modify (sent as ``SCRIPTID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SCRIPTID``.
        :return: the result of ``self.request``.
        """
        required = {"SCRIPTID": scriptid}
        body = update_params(params, required)
        return self.request("/v1/startupscript/update", body, "POST")
    def destroy(self, scriptid, params=None):
        """/v1/startupscript/destroy
        POST - account
        Delete a startup script.

        Link: https://www.vultr.com/api/#startupscript_destroy

        :param scriptid: ID of the script to delete (sent as ``SCRIPTID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SCRIPTID``.
        :return: the result of ``self.request``.
        """
        body = update_params(params, {"SCRIPTID": scriptid})
        return self.request("/v1/startupscript/destroy", body, "POST")
    def create(self, name, script, params=None):
        """/v1/startupscript/create
        POST - account
        Register a new startup script.

        Link: https://www.vultr.com/api/#startupscript_create

        :param name: name of the new script (sent as ``name``).
        :param script: script contents (sent as ``script``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``name`` and ``script``.
        :return: the result of ``self.request``.
        """
        fields = {"name": name, "script": script}
        body = update_params(params, fields)
        return self.request("/v1/startupscript/create", body, "POST")
Beispiel #11
0
    def get_user_data(self, subid, params=None):
        """/v1/server/get_user_data
        GET - account
        Fetch the base64-encoded user-data of a subscription.

        Link: https://www.vultr.com/api/#server_get_user_data

        :param subid: subscription ID (sent as ``SUBID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SUBID``.
        :return: the result of ``self.request``.
        """
        query = update_params(params, {'SUBID': subid})
        return self.request('/v1/server/get_user_data', query, 'GET')
Beispiel #12
0
    def delete_domain(self, domain, params=None):
        """/v1/dns/delete_domain
        POST - account
        Remove a domain name together with all of its records.

        Link: https://www.vultr.com/api/#dns_delete_domain

        :param domain: domain to remove (sent as ``domain``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``domain``.
        :return: the result of ``self.request``.
        """
        required = {'domain': domain}
        body = update_params(params, required)
        return self.request('/v1/dns/delete_domain', body, 'POST')
Beispiel #13
0
    def list(self, subid, params=None):
        """/v1/server/list_ipv4
        GET - account
        Show the IPv4 details of a virtual machine. IP details are only
        available while the machine is in the "active" state.

        Link: https://www.vultr.com/api/#server_list_ipv4

        :param subid: subscription ID of the machine (sent as ``SUBID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SUBID``.
        :return: the result of ``self.request``.
        """
        query = update_params(params, {'SUBID': subid})
        return self.request('/v1/server/list_ipv4', query, 'GET')
Beispiel #14
0
    def destroy(self, sshkeyid, params=None):
        """/v1/sshkey/destroy
        POST - account
        Delete an SSH key. Machines that already carry the key keep it;
        only the stored key is removed.

        Link: https://www.vultr.com/api/#sshkey_destroy

        :param sshkeyid: ID of the key to delete (sent as ``SSHKEYID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SSHKEYID``.
        :return: the result of ``self.request``.
        """
        body = update_params(params, {'SSHKEYID': sshkeyid})
        return self.request('/v1/sshkey/destroy', body, 'POST')
Beispiel #15
0
    def create(self, subid, params=None):
        """/v1/snapshot/create
        POST - account
        Take a snapshot of an existing virtual machine; the machine may
        keep running while this happens.

        Link: https://www.vultr.com/api/#snapshot_create

        :param subid: subscription ID of the source machine (sent as
            ``SUBID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SUBID``.
        :return: the result of ``self.request``.
        """
        required = {'SUBID': subid}
        body = update_params(params, required)
        return self.request('/v1/snapshot/create', body, 'POST')
Beispiel #16
0
    def start(self, subid, params=None):
        """/v1/server/start
        POST - account
        Power on a virtual machine; a machine that is already running is
        restarted instead.

        Link: https://www.vultr.com/api/#server_start

        :param subid: subscription ID of the machine (sent as ``SUBID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SUBID``.
        :return: the result of ``self.request``.
        """
        body = update_params(params, {'SUBID': subid})
        return self.request('/v1/server/start', body, 'POST')
Beispiel #17
0
    def destroy(self, snapshotid, params=None):
        """/v1/snapshot/destroy
        POST - account
        Permanently delete a snapshot; this cannot be undone.

        Link: https://www.vultr.com/api/#snapshot_destroy

        :param snapshotid: ID of the snapshot (sent as ``SNAPSHOTID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SNAPSHOTID``.
        :return: the result of ``self.request``.
        """
        required = {'SNAPSHOTID': snapshotid}
        body = update_params(params, required)
        return self.request('/v1/snapshot/destroy', body, 'POST')
Beispiel #18
0
    def neighbors(self, subid, params=None):
        """/v1/server/neighbors
        GET - account
        List the other subscriptions that share a physical host with the
        given subscription.

        Link: https://www.vultr.com/api/#server_neighbors

        :param subid: subscription ID to look up (sent as ``SUBID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SUBID``.
        :return: the result of ``self.request``.
        """
        query = update_params(params, {'SUBID': subid})
        return self.request('/v1/server/neighbors', query, 'GET')
Beispiel #19
0
    def reboot(self, subid, params=None):
        """/v1/server/reboot
        POST - account
        Hard-reboot a virtual machine (the equivalent of pulling the plug).

        Link: https://www.vultr.com/api/#server_reboot

        :param subid: subscription ID of the machine (sent as ``SUBID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SUBID``.
        :return: the result of ``self.request``.
        """
        required = {'SUBID': subid}
        body = update_params(params, required)
        return self.request('/v1/server/reboot', body, 'POST')
Beispiel #20
0
    def os_change_list(self, subid, params=None):
        """/v1/server/os_change_list
        GET - account
        List the operating systems this server may be switched to.

        Link: https://www.vultr.com/api/#server_os_change_list

        :param subid: subscription ID of the server (sent as ``SUBID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SUBID``.
        :return: the result of ``self.request``.
        """
        query = update_params(params, {'SUBID': subid})
        return self.request('/v1/server/os_change_list', query, 'GET')
Beispiel #21
0
    def upgrade_plan_list(self, subid, params=None):
        """/v1/server/upgrade_plan_list
        GET - account
        List the VPSPLANIDs a virtual machine can be upgraded to. An empty
        array means no upgrade is currently available.

        Link: https://www.vultr.com/api/#server_upgrade_plan_list

        :param subid: subscription ID of the machine (sent as ``SUBID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SUBID``.
        :return: the result of ``self.request``.
        """
        required = {'SUBID': subid}
        query = update_params(params, required)
        return self.request('/v1/server/upgrade_plan_list', query, 'GET')
Beispiel #22
0
    def destroy(self, subid, params=None):
        """/v1/server/destroy
        POST - account
        Permanently delete a virtual machine. All of its data is lost and
        its IP address is released; this cannot be undone.

        Link: https://www.vultr.com/api/#server_destroy

        :param subid: subscription ID of the machine (sent as ``SUBID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SUBID``.
        :return: the result of ``self.request``.
        """
        body = update_params(params, {'SUBID': subid})
        return self.request('/v1/server/destroy', body, 'POST')
Beispiel #23
0
    def reinstall(self, subid, params=None):
        """/v1/server/reinstall
        POST - account
        Reinstall the operating system of a virtual machine. All data is
        lost, but the IP address stays the same. This cannot be undone.

        Link: https://www.vultr.com/api/#server_reinstall

        :param subid: subscription ID of the machine (sent as ``SUBID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SUBID``.
        :return: the result of ``self.request``.
        """
        required = {'SUBID': subid}
        body = update_params(params, required)
        return self.request('/v1/server/reinstall', body, 'POST')
Beispiel #24
0
    def update(self, sshkeyid, params=None):
        """/v1/sshkey/update
        POST - account
        Modify a stored SSH key. The change only reaches machines installed
        afterwards; existing machines keep the old key.

        Link: https://www.vultr.com/api/#sshkey_update

        :param sshkeyid: ID of the key to modify (sent as ``SSHKEYID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SSHKEYID``.
        :return: the result of ``self.request``.
        """
        body = update_params(params, {'SSHKEYID': sshkeyid})
        return self.request('/v1/sshkey/update', body, 'POST')
Beispiel #25
0
    def list_ipv6(self, subid, params=None):
        """/v1/server/list_ipv6
        GET - account
        Show the IPv6 details of a virtual machine. Details are only
        available in the "active" state; a machine without IPv6 yields an
        empty array.

        Link: https://www.vultr.com/api/#server_list_ipv6

        :param subid: subscription ID of the machine (sent as ``SUBID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SUBID``.
        :return: the result of ``self.request``.
        """
        required = {'SUBID': subid}
        query = update_params(params, required)
        return self.request('/v1/server/list_ipv6', query, 'GET')
Beispiel #26
0
    def create(self, name, ssh_key, params=None):
        """/v1/sshkey/create
        POST - account
        Store a new SSH key.

        Link: https://www.vultr.com/api/#sshkey_create

        :param name: label for the key (sent as ``name``).
        :param ssh_key: the public key text (sent as ``ssh_key``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``name`` and ``ssh_key``.
        :return: the result of ``self.request``.
        """
        fields = {'name': name, 'ssh_key': ssh_key}
        body = update_params(params, fields)
        return self.request('/v1/sshkey/create', body, 'POST')
Beispiel #27
0
    def create(self, subid, params=None):
        """/v1/server/create_ipv4
        POST - account
        Attach an additional IPv4 address to a server. Billing starts
        immediately. The server is rebooted unless requested otherwise, and
        the address can only be configured after a reboot.

        Link: https://www.vultr.com/api/#server_create_ipv4

        :param subid: subscription ID of the server (sent as ``SUBID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SUBID``.
        :return: the result of ``self.request``.
        """
        body = update_params(params, {'SUBID': subid})
        return self.request('/v1/server/create_ipv4', body, 'POST')
Beispiel #28
0
    def update_record(self, domain, recordid, params=None):
        """/v1/dns/update_record
        POST - account
        Modify one DNS record of a domain.

        Link: https://www.vultr.com/api/#dns_update_record

        :param domain: domain owning the record (sent as ``domain``).
        :param recordid: ID of the record (sent as ``RECORDID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``domain`` and ``RECORDID``.
        :return: the result of ``self.request``.
        """
        fields = {'domain': domain, 'RECORDID': recordid}
        body = update_params(params, fields)
        return self.request('/v1/dns/update_record', body, 'POST')
Beispiel #29
0
    def create_domain(self, domain, ipaddr, params=None):
        """/v1/dns/create_domain
        POST - account
        Add a domain name to DNS.

        Link: https://www.vultr.com/api/#dns_create_domain

        :param domain: domain name to add (sent as ``domain``).
        :param ipaddr: IP address for the domain (sent as ``ip``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``domain`` and ``ip``.
        :return: the result of ``self.request``.
        """
        fields = {'domain': domain, 'ip': ipaddr}
        body = update_params(params, fields)
        return self.request('/v1/dns/create_domain', body, 'POST')
Beispiel #30
0
    def availability(self, dcid, params=None):
        """/v1/regions/availability
        GET - public
        List the VPSPLANIDs currently available in one location. Accounts
        with special plans must send their api_key to see those plans; for
        all other accounts the API key is not required.

        Link: https://www.vultr.com/api/#regions_region_available

        :param dcid: data-center/location ID (sent as ``DCID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``DCID``.
        :return: the result of ``self.request``.
        """
        query = update_params(params, {'DCID': dcid})
        return self.request('/v1/regions/availability', query, 'GET')
Beispiel #31
0
    def label_set(self, subid, label, params=None):
        """/v1/server/label_set
        POST - account
        Assign a label to a virtual machine.

        Link: https://www.vultr.com/api/#server_label_set

        :param subid: subscription ID of the machine (sent as ``SUBID``).
        :param label: new label text (sent as ``label``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SUBID`` and ``label``.
        :return: the result of ``self.request``.
        """
        fields = {'SUBID': subid, 'label': label}
        body = update_params(params, fields)
        return self.request('/v1/server/label_set', body, 'POST')
Beispiel #32
0
    def halt(self, subid, params=None):
        """/v1/server/halt
        POST - account
        Hard power-off of a virtual machine (like unplugging it). Its data is
        left untouched and billing continues; to delete the machine use
        /v1/server/destroy instead.

        Link: https://www.vultr.com/api/#server_halt

        :param subid: subscription ID of the machine (sent as ``SUBID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SUBID``.
        :return: the result of ``self.request``.
        """
        body = update_params(params, {'SUBID': subid})
        return self.request('/v1/server/halt', body, 'POST')
Beispiel #33
0
    def restore_backup(self, subid, backupid, params=None):
        """/v1/server/restore_backup
        POST - account
        Restore a backup onto a virtual machine, overwriting whatever data
        the machine currently holds.

        Link: https://www.vultr.com/api/#server_restore_backup

        :param subid: subscription ID of the machine (sent as ``SUBID``).
        :param backupid: ID of the backup to restore (sent as ``BACKUPID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SUBID`` and ``BACKUPID``.
        :return: the result of ``self.request``.
        """
        fields = {'SUBID': subid, 'BACKUPID': backupid}
        body = update_params(params, fields)
        return self.request('/v1/server/restore_backup', body, 'POST')
Beispiel #34
0
    def restore_snapshot(self, subid, snapshotid, params=None):
        """/v1/server/restore_snapshot
        POST - account
        Restore the specified snapshot onto a virtual machine, overwriting
        whatever data the machine currently holds.

        Link: https://www.vultr.com/api/#server_restore_snapshot

        :param subid: subscription ID of the machine (sent as ``SUBID``).
        :param snapshotid: ID of the snapshot (sent as ``SNAPSHOTID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SUBID`` and ``SNAPSHOTID``.
        :return: the result of ``self.request``.
        """
        fields = {'SUBID': subid, 'SNAPSHOTID': snapshotid}
        body = update_params(params, fields)
        return self.request('/v1/server/restore_snapshot', body, 'POST')
Beispiel #35
0
    def upgrade_plan(self, subid, vpsplanid, params=None):
        """/v1/server/upgrade_plan
        POST - account
        Move a virtual machine to a larger plan; the machine is rebooted
        once the upgrade succeeds.

        Link: https://www.vultr.com/api/#server_upgrade_plan

        :param subid: subscription ID of the machine (sent as ``SUBID``).
        :param vpsplanid: target plan ID (sent as ``VPSPLANID``).
        :param params: optional extra request parameters, passed to
            ``update_params`` together with ``SUBID`` and ``VPSPLANID``.
        :return: the result of ``self.request``.
        """
        fields = {'SUBID': subid, 'VPSPLANID': vpsplanid}
        body = update_params(params, fields)
        return self.request('/v1/server/upgrade_plan', body, 'POST')
    def learn(self):
        """Run one preference-conditioned SAC update.

        Increments ``self.learning_steps`` and, every
        ``self.target_update_interval`` steps, calls
        ``soft_update(self.critic_target, self.critic, self.tau)``. Samples a
        replay batch (with indices and priority weights under PER, else a
        constant weight of ``1.``). Draws ``preference`` plus
        ``self.set_num - 1`` further preferences from ``self.get_pref()``; the
        resulting ``pref_set`` holds ``max(1, self.set_num)`` tensors and
        starts with ``preference``. Updates both critics, the policy and, if
        ``self.entropy_tuning``, the temperature. Under PER it then refreshes
        priorities. All scalars, including ``loss/alpha``, are logged on the
        ``self.learning_steps`` axis every ``self.log_interval`` steps.
        """
        def as_float(value):
            # Logged values may be tensors or plain numbers (``self.alpha`` is
            # possibly a float when the temperature is not tuned — confirm).
            if torch.is_tensor(value):
                return value.detach().item()
            return float(value)

        self.learning_steps += 1
        if self.learning_steps % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        if self.per:
            # batch with indices and priority weights
            batch, indices, weights = self.memory.sample(self.batch_size)
        else:
            batch = self.memory.sample(self.batch_size)
            # set priority weights to 1 when we don't use PER.
            weights = 1.

        # The update's own preference is also the first member of the set.
        # (The former unused `random.randint(0, len(PREF) - 1)` was removed:
        # it did nothing but crash when PREF was empty.)
        preference = torch.tensor(self.get_pref(), device=self.device)
        pref_set = [preference]
        pref_set.extend(torch.tensor(self.get_pref(), device=self.device)
                        for _ in range(self.set_num - 1))

        q1_loss, q2_loss, errors, mean_q1, mean_q2 = self.calc_critic_loss(
            batch, weights, preference, pref_set)
        policy_loss, entropies = self.calc_policy_loss(
            batch, weights, preference, pref_set)

        update_params(self.q1_optim, self.critic.Q1, q1_loss, self.grad_clip)
        update_params(self.q2_optim, self.critic.Q2, q2_loss, self.grad_clip)
        update_params(self.policy_optim, self.policy, policy_loss,
                      self.grad_clip)

        entropy_loss = None
        if self.entropy_tuning:
            entropy_loss = self.calc_entropy_loss(entropies, weights)
            update_params(self.alpha_optim, None, entropy_loss)
            self.alpha = self.log_alpha.exp()

        if self.per:
            # update priority weights
            self.memory.update_priority(indices, errors.cpu().numpy())

        if self.learning_steps % self.log_interval == 0:
            step = self.learning_steps
            self.writer.add_scalar('loss/Q1', as_float(q1_loss), step)
            self.writer.add_scalar('loss/Q2', as_float(q2_loss), step)
            self.writer.add_scalar('loss/policy', as_float(policy_loss), step)
            if entropy_loss is not None:
                # Previously logged every call against self.steps, unlike
                # every other scalar here.
                self.writer.add_scalar('loss/alpha', as_float(entropy_loss),
                                       step)
            self.writer.add_scalar('stats/alpha', as_float(self.alpha), step)
            self.writer.add_scalar('stats/mean_Q1', mean_q1, step)
            self.writer.add_scalar('stats/mean_Q2', mean_q2, step)
            self.writer.add_scalar('stats/entropy',
                                   as_float(entropies.detach().mean()), step)
Beispiel #37
0
def main_run_disptach(pypsa_net,
                      load,
                      gen_constraints=None,
                      params=None):
    """Compute an OPF dispatch for ``load`` and return it with added noise.

    The load horizon is split by calendar month and then by the groups
    returned from ``get_grouped_snapshots(..., params['mode_opf'])``.
    ``run_opf`` is called on each group, and the stacked results are sorted
    by index. When ``params['step_opf_min'] > 5`` they are interpolated with
    ``interpolate_dispatch``. Finally noise is added through
    ``add_noise_gen`` using the generators' ``p_nom``.

    Parameters
    ----------
    pypsa_net :
        PyPSA network (must expose ``generators.p_nom``); it is passed
        through ``preprocess_net`` first.
    load :
        Load data; presumably a pandas DataFrame (``load.shape[0]`` is used)
        — confirm with ``preprocess_input_data``.
    gen_constraints : dict, optional
        Generator constraints with keys ``'p_max_pu'`` and ``'p_min_pu'``.
        Defaults to ``{'p_max_pu': None, 'p_min_pu': None}`` and is passed
        through ``update_gen_constrains``.
    params : dict, optional
        User parameters, passed through ``update_params`` (defaults to
        ``{}``). Must end up containing ``'step_opf_min'`` and ``'mode_opf'``.

    Returns
    -------
    The value returned by ``add_noise_gen`` for the concatenated dispatch.
    """
    # Fresh defaults per call: the previous dict literals in the signature
    # were shared by every call and could be mutated by the helpers.
    if gen_constraints is None:
        gen_constraints = {'p_max_pu': None, 'p_min_pu': None}
    if params is None:
        params = {}

    # Update gen constrains dict with
    # values passed by the users and params
    gen_constraints = update_gen_constrains(gen_constraints)
    params = update_params(load.shape[0], params)

    print('Preprocessing input data..')
    # Preprocess input data:
    #   - Add date range as index
    #   - Check whether gen constraints has same length as load
    load_, gen_constraints_ = preprocess_input_data(load, gen_constraints,
                                                    params)
    tot_snap = load_.index

    print('Adapting PyPSA grid with parameters..')
    # Preprocess net parameters:
    #   - Change ramps according to params step_opf_min (assuming original
    #     values are normalizing for every 5 minutes)
    #   - It checks for all gen units if commitable variables is False
    #     (commitable as False helps to create a LP problem for PyPSA)
    pypsa_net = preprocess_net(pypsa_net, params['step_opf_min'])

    # Generator constraint tables do not depend on the month: extract once.
    g_max_pu = gen_constraints_['p_max_pu']
    g_min_pu = gen_constraints_['p_min_pu']

    months = tot_snap.month.unique()
    # NOTE: timing covers the OPF runs and post-processing, not preprocessing.
    start = time.time()
    results = []
    for month in months:
        # Snapshots and input data for this month
        snap_per_month = tot_snap[tot_snap.month == month]
        load_per_month = load_.loc[snap_per_month]
        g_max_pu_per_month = g_max_pu.loc[snap_per_month]
        g_min_pu_per_month = g_min_pu.loc[snap_per_month]
        # Group the monthly snapshots according to the OPF mode
        snap_per_mode = get_grouped_snapshots(snap_per_month,
                                              params['mode_opf'])
        for snaps in snap_per_mode:
            # Run opf on the input data truncated to this group
            results.append(
                run_opf(
                    pypsa_net,
                    load_per_month.loc[snaps],
                    g_max_pu_per_month.loc[snaps],
                    g_min_pu_per_month.loc[snaps],
                    params,
                ))

    # Stack individual dispatches with a single concat (growing a DataFrame
    # inside the loop is quadratic); no results still yields an empty frame.
    prod_p = pd.concat(results, axis=0) if results else pd.DataFrame()
    # Sort by datetime
    prod_p = prod_p.sort_index()
    # Apply interpolation in case of step_opf_min greater than 5 min
    if params['step_opf_min'] > 5:
        print('\n => Interpolating dispatch to have 5 minutes resolution..')
        prod_p = interpolate_dispatch(prod_p)
    # Add noise to results
    gen_cap = pypsa_net.generators.p_nom
    prod_p_with_noise = add_noise_gen(prod_p, gen_cap, noise_factor=0.001)

    end = time.time()
    print('Total time {} min'.format(round((end - start) / 60, 2)))
    print('OPF Done......')

    return prod_p_with_noise
Beispiel #38
0
    lr = args.lr
    if epoch >= 400:
        lr = args.lr * 0.01
    elif epoch >= 200:
        lr = args.lr * 0.1
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


## Prepare for Training
transforms = tf.load_transforms(args.transform)
trainset = tf.load_trainset(args.data, transforms, path=args.data_dir)
# Validate the class split before loading checkpoints or touching model_dir.
# Raise explicitly: an `assert` would be stripped under `python -O`.
if trainset.num_classes % args.cpb != 0:
    raise ValueError("Number of classes not divisible by cpb")
if args.pretrain_dir is not None:
    net, _ = tf.load_checkpoint(args.pretrain_dir, args.pretrain_epo)
    # NOTE(review): presumably records the pretrain source in model_dir's
    # params — confirm against utils.update_params.
    utils.update_params(model_dir, args.pretrain_dir)
else:
    net = tf.load_architectures_ce(args.arch, trainset.num_classes)

# Split the unique class labels into rows of `args.cpb` classes each.
classes = np.unique(trainset.targets)
class_batch_num = trainset.num_classes // args.cpb
class_batch_list = classes.reshape(class_batch_num, args.cpb)

criterion = nn.CrossEntropyLoss()
optimizer = SGD(net.parameters(),
                lr=args.lr,
                momentum=args.mom,
                weight_decay=args.wd)
Beispiel #39
0
            """decrease the learning rate"""
            lr = args.lr
            if epoch >= 400:
                lr = args.lr * 0.01
            elif epoch >= 200:
                lr = args.lr * 0.1
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr


        ## Prepare for Training
        if args.pretrain_dir is not None:
            pretrain_model_dir = os.path.join(args.pretrain_dir,
                                'sup_expert_resnet18+128_{}_epo200_bs1000_lr0.001_mom0.9_wd0.0005_gam11.0_gam21.0_eps0.5_lcr0.0'.format(source_name))
            net, _ = tf.load_checkpoint(pretrain_model_dir, args.pretrain_epo)
            utils.update_params(model_dir, pretrain_model_dir)
        else:
            net = tf.load_architectures(args.arch, args.fd)
        transforms = tf.load_transforms(args.transform)
        trainset = tf.load_trainset(ds_name, transforms, path=args.data_dir)
        print("Number of classes in {} is: {}".format(ds_name,trainset.num_classes))
        trainset = tf.corrupt_labels(trainset, args.lcr, args.lcs)
        trainloader = DataLoader(trainset, batch_size=args.bs, drop_last=True, num_workers=4)
        criterion = MaximalCodingRateReduction(gam1=args.gam1, gam2=args.gam2, eps=args.eps)
        optimizer = SGD(net.parameters(), lr=args.lr, momentum=args.mom, weight_decay=args.wd)


        ## Training
        for epoch in range(args.epo):
            lr_schedule(epoch, optimizer)
            for step, (batch_imgs, batch_lbls) in enumerate(trainloader):
Beispiel #40
0
        else:
            break


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # --config is mandatory: without it open(None) below failed with an
    # unhelpful TypeError instead of an argparse usage error.
    parser.add_argument('-c',
                        '--config',
                        type=str,
                        required=True,
                        help='JSON file for configuration')
    parser.add_argument('-p', '--params', nargs='+', default=[])
    args = parser.parse_args()
    args.rank = 0

    # Parse configs.  Globals nicer in this case
    # JSON text is UTF-8 by specification, so do not rely on the locale.
    with open(args.config, encoding='utf-8') as f:
        config = json.load(f)
    # Command-line -p/--params values are handed to update_params with the
    # loaded config (presumably as overrides — confirm).
    update_params(config, args.params)

    train_config = config["train_config"]
    model_config = config["model_config"]
    data_config = config["data_config"]

    # NOTE: deliberately rebinds the module-level name `print`.
    print = Printer()

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    main(**train_config, model_config=model_config, data_config=data_config)
Beispiel #41
0
        [.8, .9],
        [.2, .1],
        [.3, .2],
        [.4, .3],
        [.5, .4],
        [.6, .5],
        [.7, .6],
        [.8, .7],
    ])

    hps = {
        'lr': .01,  # <-- learning rate
    }

    params = build_params(
        inputs.shape[1],  # <-- num features
        outcome_variables.shape[1])

    num_epochs = 100

    print('loss initially: ',
          loss(params, inputs=inputs, targets=outcome_variables, hps=hps))
    for epoch in range(num_epochs):
        gradients = loss_grad(params,
                              inputs=inputs,
                              targets=outcome_variables,
                              hps=hps)
        params = utils.update_params(params, gradients, hps['lr'])
    print('loss after training: ',
          loss(params, inputs=inputs, targets=outcome_variables, hps=hps))