]> wimlib.net Git - wimlib/blobdiff - src/win32_common.c
WIMBoot: Update WimOverlay.dat directly when WOF not running
[wimlib] / src / win32_common.c
index e5d6a386f26fd3dfae8595ef3837bfcea580f15b..67d208d23256c5daa347dfb0cb59cf7c65890f4b 100644 (file)
 
 #include <errno.h>
 
-#include "wimlib/win32_common.h"
+#ifdef WITH_NTDLL
+#  include <winternl.h>
+#endif
 
+#include "wimlib/win32_common.h"
 #include "wimlib/assert.h"
 #include "wimlib/error.h"
 #include "wimlib/util.h"
 
-#ifdef ENABLE_ERROR_MESSAGES
-void
-win32_error(DWORD err_code)
-{
-       wchar_t *buffer;
-       DWORD nchars;
-       nchars = FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM |
-                                   FORMAT_MESSAGE_ALLOCATE_BUFFER,
-                               NULL, err_code, 0,
-                               (wchar_t*)&buffer, 0, NULL);
-       if (nchars == 0) {
-               ERROR("Error printing error message! "
-                     "Computer will self-destruct in 3 seconds.");
-       } else {
-               ERROR("Win32 error: %ls", buffer);
-               LocalFree(buffer);
-       }
-}
-#endif /* ENABLE_ERROR_MESSAGES */
-
-int
+static int
 win32_error_to_errno(DWORD err_code)
 {
        /* This mapping is that used in Cygwin.
@@ -324,11 +307,26 @@ win32_error_to_errno(DWORD err_code)
        }
 }
 
+
+void
+set_errno_from_win32_error(DWORD err)
+{
+       errno = win32_error_to_errno(err);
+}
+
 void
 set_errno_from_GetLastError(void)
 {
-       errno = win32_error_to_errno(GetLastError());
+       set_errno_from_win32_error(GetLastError());
+}
+
+#ifdef WITH_NTDLL
+void
+set_errno_from_nt_status(NTSTATUS status)
+{
+       set_errno_from_win32_error((*func_RtlNtStatusToDosError)(status));
 }
+#endif
 
 /* Given a Windows-style path, return the number of characters of the prefix
  * that specify the path to the root directory of a drive, or return 0 if the
@@ -363,6 +361,12 @@ bool
 win32_path_is_root_of_drive(const wchar_t *path)
 {
        size_t drive_spec_len;
+       wchar_t full_path[32768];
+       DWORD ret;
+
+       ret = GetFullPathName(path, ARRAY_LEN(full_path), full_path, NULL);
+       if (ret > 0 && ret < ARRAY_LEN(full_path))
+               path = full_path;
 
        /* Explicit drive letter and path separator? */
        drive_spec_len = win32_path_drive_spec_len(path);
@@ -374,23 +378,23 @@ win32_path_is_root_of_drive(const wchar_t *path)
                if (!is_any_path_separator(*p))
                        return false;
        return true;
-
-       /* XXX This function does not handle paths like "c:" where the working
-        * directory on "c:" is actually "c:\", or weird paths like "\.".  But
-        * currently the capture and apply code always prefixes the paths with
-        * \\?\ anyway so this is irrelevant... */
 }
 
 
 /* Given a path, which may not yet exist, get a set of flags that describe the
  * features of the volume the path is on. */
 int
-win32_get_vol_flags(const wchar_t *path, unsigned *vol_flags_ret)
+win32_get_vol_flags(const wchar_t *path, unsigned *vol_flags_ret,
+                   bool *supports_SetFileShortName_ret)
 {
        wchar_t *volume;
        BOOL bret;
        DWORD vol_flags;
        size_t drive_spec_len;
+       wchar_t filesystem_name[MAX_PATH + 1];
+
+       if (supports_SetFileShortName_ret)
+               *supports_SetFileShortName_ret = false;
 
        drive_spec_len = win32_path_drive_spec_len(path);
 
@@ -412,43 +416,100 @@ win32_get_vol_flags(const wchar_t *path, unsigned *vol_flags_ret)
                volume[drive_spec_len] = L'\\';
                volume[drive_spec_len + 1] = L'\0';
        }
-       bret = GetVolumeInformationW(volume, /* lpRootPathName */
-                                    NULL,  /* lpVolumeNameBuffer */
-                                    0,     /* nVolumeNameSize */
-                                    NULL,  /* lpVolumeSerialNumber */
-                                    NULL,  /* lpMaximumComponentLength */
-                                    &vol_flags, /* lpFileSystemFlags */
-                                    NULL,  /* lpFileSystemNameBuffer */
-                                    0);    /* nFileSystemNameSize */
+       bret = GetVolumeInformation(
+                       volume,                         /* lpRootPathName */
+                       NULL,                           /* lpVolumeNameBuffer */
+                       0,                              /* nVolumeNameSize */
+                       NULL,                           /* lpVolumeSerialNumber */
+                       NULL,                           /* lpMaximumComponentLength */
+                       &vol_flags,                     /* lpFileSystemFlags */
+                       filesystem_name,                /* lpFileSystemNameBuffer */
+                       ARRAY_LEN(filesystem_name));    /* nFileSystemNameSize */
        if (!bret) {
-               DWORD err = GetLastError();
-               WARNING("Failed to get volume information for path \"%ls\"", path);
-               win32_error(err);
+               set_errno_from_GetLastError();
+               WARNING_WITH_ERRNO("Failed to get volume information for "
+                                  "path \"%ls\"", path);
                vol_flags = 0xffffffff;
+               goto out;
+       }
+
+       if (wcsstr(filesystem_name, L"NTFS")) {
+               /* FILE_SUPPORTS_HARD_LINKS is only supported on Windows 7 and later.
+                * Force it on anyway if filesystem is NTFS.  */
+               vol_flags |= FILE_SUPPORTS_HARD_LINKS;
+
+               if (supports_SetFileShortName_ret)
+                       *supports_SetFileShortName_ret = true;
        }
 
+out:
        DEBUG("using vol_flags = %x", vol_flags);
        *vol_flags_ret = vol_flags;
        return 0;
 }
 
-HANDLE
-win32_open_existing_file(const wchar_t *path, DWORD dwDesiredAccess)
+static bool
+win32_modify_privilege(const wchar_t *privilege, bool enable)
+{
+       HANDLE hToken;
+       LUID luid;
+       TOKEN_PRIVILEGES newState;
+       bool ret = FALSE;
+
+       if (!OpenProcessToken(GetCurrentProcess(),
+                             TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY,
+                             &hToken))
+               goto out;
+
+       if (!LookupPrivilegeValue(NULL, privilege, &luid))
+               goto out_close_handle;
+
+       newState.PrivilegeCount = 1;
+       newState.Privileges[0].Luid = luid;
+       newState.Privileges[0].Attributes = (enable ? SE_PRIVILEGE_ENABLED : 0);
+       SetLastError(ERROR_SUCCESS);
+       ret = AdjustTokenPrivileges(hToken, FALSE, &newState, 0, NULL, NULL);
+       if (ret && GetLastError() == ERROR_NOT_ALL_ASSIGNED)
+               ret = FALSE;
+out_close_handle:
+       CloseHandle(hToken);
+out:
+       return ret;
+}
+
+static bool
+win32_modify_capture_privileges(bool enable)
+{
+       return win32_modify_privilege(SE_BACKUP_NAME, enable)
+           && win32_modify_privilege(SE_SECURITY_NAME, enable);
+}
+
+static bool
+win32_modify_apply_privileges(bool enable)
+{
+       return win32_modify_privilege(SE_RESTORE_NAME, enable)
+           && win32_modify_privilege(SE_SECURITY_NAME, enable)
+           && win32_modify_privilege(SE_TAKE_OWNERSHIP_NAME, enable);
+}
+
+static void
+win32_release_capture_and_apply_privileges(void)
 {
-       return CreateFileW(path,
-                          dwDesiredAccess,
-                          FILE_SHARE_READ,
-                          NULL, /* lpSecurityAttributes */
-                          OPEN_EXISTING,
-                          FILE_FLAG_BACKUP_SEMANTICS |
-                              FILE_FLAG_OPEN_REPARSE_POINT,
-                          NULL /* hTemplateFile */);
+       win32_modify_capture_privileges(false);
+       win32_modify_apply_privileges(false);
 }
 
 HANDLE
-win32_open_file_data_only(const wchar_t *path)
+win32_open_existing_file(const wchar_t *path, DWORD dwDesiredAccess)
 {
-       return win32_open_existing_file(path, FILE_READ_DATA);
+       return CreateFile(path,
+                         dwDesiredAccess,
+                         FILE_SHARE_READ,
+                         NULL, /* lpSecurityAttributes */
+                         OPEN_EXISTING,
+                         FILE_FLAG_BACKUP_SEMANTICS |
+                               FILE_FLAG_OPEN_REPARSE_POINT,
+                         NULL /* hTemplateFile */);
 }
 
 /* Pointers to functions that are not available on all targetted versions of
@@ -465,12 +526,60 @@ HANDLE (WINAPI *win32func_FindFirstStreamW)(LPCWSTR lpFileName,
 BOOL (WINAPI *win32func_FindNextStreamW)(HANDLE hFindStream,
                                         LPVOID lpFindStreamData) = NULL;
 
+/* Vista and later */
+BOOL (WINAPI *win32func_CreateSymbolicLinkW)(const wchar_t *lpSymlinkFileName,
+                                            const wchar_t *lpTargetFileName,
+                                            DWORD dwFlags) = NULL;
+
+#ifdef WITH_NTDLL
+
+DWORD (WINAPI *func_RtlNtStatusToDosError)(NTSTATUS status);
+
+NTSTATUS (WINAPI *func_NtQueryInformationFile)(HANDLE FileHandle,
+                                              PIO_STATUS_BLOCK IoStatusBlock,
+                                              PVOID FileInformation,
+                                              ULONG Length,
+                                              FILE_INFORMATION_CLASS FileInformationClass);
+
+NTSTATUS (WINAPI *func_NtQuerySecurityObject)(HANDLE handle,
+                                             SECURITY_INFORMATION SecurityInformation,
+                                             PSECURITY_DESCRIPTOR SecurityDescriptor,
+                                             ULONG Length,
+                                             PULONG LengthNeeded);
+
+NTSTATUS (WINAPI *func_NtQueryDirectoryFile) (HANDLE FileHandle,
+                                             HANDLE Event,
+                                             PIO_APC_ROUTINE ApcRoutine,
+                                             PVOID ApcContext,
+                                             PIO_STATUS_BLOCK IoStatusBlock,
+                                             PVOID FileInformation,
+                                             ULONG Length,
+                                             FILE_INFORMATION_CLASS FileInformationClass,
+                                             BOOLEAN ReturnSingleEntry,
+                                             PUNICODE_STRING FileName,
+                                             BOOLEAN RestartScan);
+
+NTSTATUS (WINAPI *func_NtSetSecurityObject)(HANDLE Handle,
+                                           SECURITY_INFORMATION SecurityInformation,
+                                           PSECURITY_DESCRIPTOR SecurityDescriptor);
+
+NTSTATUS (WINAPI *func_RtlCreateSystemVolumeInformationFolder)
+               (PCUNICODE_STRING VolumeRootPath);
+
+#endif /* WITH_NTDLL */
+
 static OSVERSIONINFO windows_version_info = {
        .dwOSVersionInfoSize = sizeof(OSVERSIONINFO),
 };
 
 static HMODULE hKernel32 = NULL;
 
+#ifdef WITH_NTDLL
+static HMODULE hNtdll = NULL;
+#endif
+
+static bool acquired_privileges = false;
+
 bool
 windows_version_is_at_least(unsigned major, unsigned minor)
 {
@@ -479,22 +588,28 @@ windows_version_is_at_least(unsigned major, unsigned minor)
                 windows_version_info.dwMinorVersion >= minor);
 }
 
-/* Try to dynamically load some functions */
-void
-win32_global_init(void)
+/* One-time initialization for Windows capture/apply code.  */
+int
+win32_global_init(int init_flags)
 {
-       DWORD err;
-
-       if (hKernel32 == NULL) {
-               DEBUG("Loading Kernel32.dll");
-               hKernel32 = LoadLibraryW(L"Kernel32.dll");
-               if (hKernel32 == NULL) {
-                       err = GetLastError();
-                       WARNING("Can't load Kernel32.dll");
-                       win32_error(err);
-               }
+       /* Try to acquire useful privileges.  */
+       if (!(init_flags & WIMLIB_INIT_FLAG_DONT_ACQUIRE_PRIVILEGES)) {
+               if (!win32_modify_capture_privileges(true))
+                       if (init_flags & WIMLIB_INIT_FLAG_STRICT_CAPTURE_PRIVILEGES)
+                               goto insufficient_privileges;
+               if (!win32_modify_apply_privileges(true))
+                       if (init_flags & WIMLIB_INIT_FLAG_STRICT_APPLY_PRIVILEGES)
+                               goto insufficient_privileges;
+               acquired_privileges = true;
        }
 
+       /* Get Windows version information.  */
+       GetVersionEx(&windows_version_info);
+
+       /* Try to dynamically load some functions.  */
+       if (hKernel32 == NULL)
+               hKernel32 = LoadLibrary(L"Kernel32.dll");
+
        if (hKernel32) {
                win32func_FindFirstStreamW = (void*)GetProcAddress(hKernel32,
                                                                   "FindFirstStreamW");
@@ -504,19 +619,69 @@ win32_global_init(void)
                        if (!win32func_FindNextStreamW)
                                win32func_FindFirstStreamW = NULL;
                }
+               win32func_CreateSymbolicLinkW = (void*)GetProcAddress(hKernel32,
+                                                                     "CreateSymbolicLinkW");
        }
 
-       GetVersionEx(&windows_version_info);
+#ifdef WITH_NTDLL
+       if (hNtdll == NULL)
+               hNtdll = LoadLibrary(L"ntdll.dll");
+
+       if (hNtdll) {
+               func_RtlNtStatusToDosError  =
+                       (void*)GetProcAddress(hNtdll, "RtlNtStatusToDosError");
+               if (func_RtlNtStatusToDosError) {
+
+                       func_NtQuerySecurityObject  =
+                               (void*)GetProcAddress(hNtdll, "NtQuerySecurityObject");
+
+                       func_NtQueryDirectoryFile   =
+                               (void*)GetProcAddress(hNtdll, "NtQueryDirectoryFile");
+
+                       func_NtQueryInformationFile =
+                               (void*)GetProcAddress(hNtdll, "NtQueryInformationFile");
+
+                       func_NtSetSecurityObject    =
+                               (void*)GetProcAddress(hNtdll, "NtSetSecurityObject");
+                       func_RtlCreateSystemVolumeInformationFolder =
+                               (void*)GetProcAddress(hNtdll, "RtlCreateSystemVolumeInformationFolder");
+               }
+       }
+
+       DEBUG("FindFirstStreamW       @ %p", win32func_FindFirstStreamW);
+       DEBUG("FindNextStreamW        @ %p", win32func_FindNextStreamW);
+       DEBUG("CreateSymbolicLinkW    @ %p", win32func_CreateSymbolicLinkW);
+       DEBUG("RtlNtStatusToDosError  @ %p", func_RtlNtStatusToDosError);
+       DEBUG("NtQuerySecurityObject  @ %p", func_NtQuerySecurityObject);
+       DEBUG("NtQueryDirectoryFile   @ %p", func_NtQueryDirectoryFile);
+       DEBUG("NtQueryInformationFile @ %p", func_NtQueryInformationFile);
+       DEBUG("NtSetSecurityObject    @ %p", func_NtSetSecurityObject);
+       DEBUG("RtlCreateSystemVolumeInformationFolder    @ %p",
+             func_RtlCreateSystemVolumeInformationFolder);
+#endif
+
+       return 0;
+
+insufficient_privileges:
+       win32_release_capture_and_apply_privileges();
+       return WIMLIB_ERR_INSUFFICIENT_PRIVILEGES;
 }
 
 void
 win32_global_cleanup(void)
 {
+       if (acquired_privileges)
+               win32_release_capture_and_apply_privileges();
        if (hKernel32 != NULL) {
-               DEBUG("Closing Kernel32.dll");
                FreeLibrary(hKernel32);
                hKernel32 = NULL;
        }
+#ifdef WITH_NTDLL
+       if (hNtdll != NULL) {
+               FreeLibrary(hNtdll);
+               hNtdll = NULL;
+       }
+#endif
 }
 
 #endif /* __WIN32__ */